This is an automated email from the ASF dual-hosted git repository.
nju_yaho pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 524a3c534a Count agg support multiple expressions (#5908)
524a3c534a is described below
commit 524a3c534a9307e9d29ba5bf8aed2d2a70ae68e6
Author: allenma <[email protected]>
AuthorDate: Mon Apr 10 13:57:10 2023 +0800
Count agg support multiple expressions (#5908)
* Count agg support multiple expressions
* Address review comments
---
datafusion/core/tests/sql/aggregates.rs | 43 +++++++++++
datafusion/core/tests/sql/errors.rs | 2 +-
datafusion/expr/src/aggregate_function.rs | 4 +-
datafusion/expr/src/signature.rs | 9 +++
datafusion/expr/src/type_coercion/aggregates.rs | 8 ++
datafusion/expr/src/type_coercion/functions.rs | 3 +
datafusion/physical-expr/src/aggregate/build_in.rs | 12 +--
datafusion/physical-expr/src/aggregate/count.rs | 85 ++++++++++++++++++++--
8 files changed, 151 insertions(+), 15 deletions(-)
diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs
index e7324eed01..2496331a01 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -500,6 +500,49 @@ async fn count_aggregated_cube() -> Result<()> {
Ok(())
}
+/// `count(c1, c2)` counts only the rows in which *every* argument is non-null.
+/// In the batch below only row 0 `(0, 1)` and row 2 `(1, 0)` have both `c1`
+/// and `c2` set, so the expected result is 2.
+#[tokio::test]
+async fn count_multi_expr() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("c1", DataType::Int32, true),
+ Field::new("c2", DataType::Int32, true),
+ ]));
+
+ // Row-wise (c1, c2): (0, 1), (NULL, 1), (1, 0), (2, NULL), (NULL, NULL)
+ let data = RecordBatch::try_new(
+ schema.clone(),
+ vec![
+ Arc::new(Int32Array::from(vec![
+ Some(0),
+ None,
+ Some(1),
+ Some(2),
+ None,
+ ])),
+ Arc::new(Int32Array::from(vec![
+ Some(1),
+ Some(1),
+ Some(0),
+ None,
+ None,
+ ])),
+ ],
+ )?;
+
+ let ctx = SessionContext::new();
+ ctx.register_batch("test", data)?;
+ let sql = "SELECT count(c1, c2) FROM test";
+ let actual = execute_to_batches(&ctx, sql).await;
+
+ // The output column name lists every argument of the aggregate.
+ let expected = vec![
+ "+------------------------+",
+ "| COUNT(test.c1,test.c2) |",
+ "+------------------------+",
+ "| 2 |",
+ "+------------------------+",
+ ];
+ assert_batches_sorted_eq!(expected, &actual);
+ Ok(())
+}
+
#[tokio::test]
async fn simple_avg() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
diff --git a/datafusion/core/tests/sql/errors.rs b/datafusion/core/tests/sql/errors.rs
index 64966c44e6..f04531417d 100644
--- a/datafusion/core/tests/sql/errors.rs
+++ b/datafusion/core/tests/sql/errors.rs
@@ -58,7 +58,7 @@ async fn test_aggregation_with_bad_arguments() -> Result<()> {
assert_eq!(
err,
DataFusionError::Plan(
- "The function Count expects 1 arguments, but 0 were
provided".to_string()
+ "The function Count expects at least one argument".to_string()
)
.to_string()
);
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index b7fb7d47d2..f1d5ea0092 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -165,8 +165,8 @@ pub fn return_type(
pub fn signature(fun: &AggregateFunction) -> Signature {
// note: the physical expression must accept the type returned by this
function or the execution panics.
match fun {
- AggregateFunction::Count
- | AggregateFunction::ApproxDistinct
+ AggregateFunction::Count =>
Signature::variadic_any(Volatility::Immutable),
+ AggregateFunction::ApproxDistinct
| AggregateFunction::Grouping
| AggregateFunction::ArrayAgg => Signature::any(1,
Volatility::Immutable),
AggregateFunction::Min | AggregateFunction::Max => {
diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs
index 19909cf2fb..3535c885c3 100644
--- a/datafusion/expr/src/signature.rs
+++ b/datafusion/expr/src/signature.rs
@@ -46,6 +46,8 @@ pub enum TypeSignature {
// A function such as `array` is `VariadicEqual`
// The first argument decides the type used for coercion
VariadicEqual,
+ /// arbitrary number of arguments with arbitrary types
+ VariadicAny,
/// fixed number of arguments of an arbitrary but equal type out of a list
of valid types
// A function of one argument of f64 is `Uniform(1,
vec![DataType::Float64])`
// A function of one argument of f64 or f32 is `Uniform(1,
vec![DataType::Float32, DataType::Float64])`
@@ -89,6 +91,13 @@ impl Signature {
volatility,
}
}
+ /// variadic_any - Creates a variadic signature that represents an arbitrary number of arguments of any type.
+ ///
+ /// During type coercion the argument types are passed through unchanged
+ /// (no casts are inserted). For aggregate functions, an empty argument list
+ /// is rejected at planning time with "expects at least one argument".
+ pub fn variadic_any(volatility: Volatility) -> Self {
+ Self {
+ type_signature: TypeSignature::VariadicAny,
+ volatility,
+ }
+ }
/// uniform - Creates a function with a fixed number of arguments of the
same type, which must be from valid_types.
pub fn uniform(
arg_count: usize,
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index 3ad197afb6..3d4b9646dc 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -263,6 +263,14 @@ fn check_arg_count(
)));
}
}
+ TypeSignature::VariadicAny => {
+ if input_types.is_empty() {
+ return Err(DataFusionError::Plan(format!(
+ "The function {:?} expects at least one argument",
+ agg_fun
+ )));
+ }
+ }
_ => {
return Err(DataFusionError::Internal(format!(
"Aggregate functions do not support this {signature:?}"
diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs
index a038fdcc92..d86914325f 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -78,6 +78,9 @@ fn get_valid_types(
.map(|_| current_types[0].clone())
.collect()]
}
+ TypeSignature::VariadicAny => {
+ vec![current_types.to_vec()]
+ }
TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
TypeSignature::Any(number) => {
if current_types.len() != *number {
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index b3dbef7dfd..b1e03fb5d9 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -52,11 +52,13 @@ pub fn create_aggregate_expr(
let input_phy_exprs = input_phy_exprs.to_vec();
Ok(match (fun, distinct) {
- (AggregateFunction::Count, false) => Arc::new(expressions::Count::new(
- input_phy_exprs[0].clone(),
- name,
- return_type,
- )),
+ (AggregateFunction::Count, false) => {
+ Arc::new(expressions::Count::new_with_multiple_exprs(
+ input_phy_exprs,
+ name,
+ return_type,
+ ))
+ }
(AggregateFunction::Count, true) =>
Arc::new(expressions::DistinctCount::new(
input_phy_types[0].clone(),
input_phy_exprs[0].clone(),
diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs
index dc77b794a2..03a5c60a94 100644
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ b/datafusion/physical-expr/src/aggregate/count.rs
@@ -19,14 +19,16 @@
use std::any::Any;
use std::fmt::Debug;
+use std::ops::BitAnd;
use std::sync::Arc;
use crate::aggregate::row_accumulator::RowAccumulator;
use crate::{AggregateExpr, PhysicalExpr};
-use arrow::array::Int64Array;
+use arrow::array::{Array, Int64Array};
use arrow::compute;
use arrow::datatypes::DataType;
use arrow::{array::ArrayRef, datatypes::Field};
+use arrow_buffer::BooleanBuffer;
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
@@ -41,7 +43,7 @@ pub struct Count {
name: String,
data_type: DataType,
nullable: bool,
- expr: Arc<dyn PhysicalExpr>,
+ exprs: Vec<Arc<dyn PhysicalExpr>>,
}
impl Count {
@@ -53,11 +55,43 @@ impl Count {
) -> Self {
Self {
name: name.into(),
- expr,
+ exprs: vec![expr],
data_type,
nullable: true,
}
}
+
+ /// Create a `Count` aggregate over one or more input expressions
+ /// (e.g. SQL `COUNT(a, b)`).
+ ///
+ /// A row contributes to the count only if every expression in `exprs`
+ /// evaluates to a non-null value for that row. The result is marked
+ /// nullable, matching `Count::new`.
+ ///
+ /// NOTE(review): `exprs` is not checked for emptiness here; the accumulators
+ /// index the first input column, so callers must pass at least one
+ /// expression — confirm all call sites guarantee this.
+ pub fn new_with_multiple_exprs(
+ exprs: Vec<Arc<dyn PhysicalExpr>>,
+ name: impl Into<String>,
+ data_type: DataType,
+ ) -> Self {
+ Self {
+ name: name.into(),
+ exprs,
+ data_type,
+ nullable: true,
+ }
+ }
+}
+
+/// Returns the number of rows in which at least one of `values` is null.
+///
+/// For several columns, the validity bitmaps are AND-ed together. A bit stays
+/// set only if the row is valid in every column, so `len - set_bits` is the
+/// number of rows with one or more nulls. Columns without a null buffer
+/// contribute nothing to the AND. If no column has one, the result is 0.
+///
+/// # Panics
+///
+/// Panics if `values` is empty (it indexes `values[0]`).
+///
+/// NOTE(review): assumes every array has the same length as `values[0]` —
+/// confirm callers always pass columns from a single batch.
+fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
+ if values.len() > 1 {
+ let result_bool_buf: Option<BooleanBuffer> = values
+ .iter()
+ .map(|a| a.data().nulls())
+ // Accumulator `None` = no null buffer seen yet, i.e. all rows valid so far.
+ .fold(None, |acc, b| match (acc, b) {
+ (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
+ (Some(acc), None) => Some(acc),
+ (None, Some(b)) => Some(b.inner().clone()),
+ _ => None,
+ });
+ result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
+ } else {
+ // Single column: its own null count is already the answer.
+ values[0].null_count()
+ }
}
impl AggregateExpr for Count {
@@ -83,7 +117,7 @@ impl AggregateExpr for Count {
}
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- vec![self.expr.clone()]
+ self.exprs.clone()
}
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
@@ -137,13 +171,13 @@ impl Accumulator for CountAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let array = &values[0];
- self.count += (array.len() - array.null_count()) as i64;
+ self.count += (array.len() - null_count_for_multiple_cols(values)) as
i64;
Ok(())
}
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let array = &values[0];
- self.count -= (array.len() - array.null_count()) as i64;
+ self.count -= (array.len() - null_count_for_multiple_cols(values)) as
i64;
Ok(())
}
@@ -183,7 +217,7 @@ impl RowAccumulator for CountRowAccumulator {
accessor: &mut RowAccessor,
) -> Result<()> {
let array = &values[0];
- let delta = (array.len() - array.null_count()) as u64;
+ let delta = (array.len() - null_count_for_multiple_cols(values)) as
u64;
accessor.add_u64(self.state_index, delta);
Ok(())
}
@@ -270,4 +304,41 @@ mod tests {
Arc::new(LargeStringArray::from(vec!["a", "bb", "ccc", "dddd",
"ad"]));
generic_test_op!(a, DataType::LargeUtf8, Count,
ScalarValue::from(5i64))
}
+
+ /// Multi-column `Count` counts only rows where all inputs are non-null.
+ /// Rows valid in both `a` and `b` are row 0 `(1, 1)` and row 4 `(3, 3)`;
+ /// every other row has a null in at least one column, so the expected
+ /// count is 2.
+ #[test]
+ fn count_multi_cols() -> Result<()> {
+ let a: ArrayRef = Arc::new(Int32Array::from(vec![
+ Some(1),
+ Some(2),
+ None,
+ None,
+ Some(3),
+ None,
+ ]));
+ let b: ArrayRef = Arc::new(Int32Array::from(vec![
+ Some(1),
+ None,
+ Some(2),
+ None,
+ Some(3),
+ Some(4),
+ ]));
+ let schema = Schema::new(vec![
+ Field::new("a", DataType::Int32, true),
+ Field::new("b", DataType::Int32, true),
+ ]);
+
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a,
b])?;
+
+ // One `Count` over both columns, i.e. the physical form of `COUNT(a, b)`.
+ let agg = Arc::new(Count::new_with_multiple_exprs(
+ vec![col("a", &schema)?, col("b", &schema)?],
+ "bla".to_string(),
+ DataType::Int64,
+ ));
+ let actual = aggregate(&batch, agg)?;
+ let expected = ScalarValue::from(2i64);
+
+ assert_eq!(expected, actual);
+ Ok(())
+ }
}