This is an automated email from the ASF dual-hosted git repository.

nju_yaho pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 524a3c534a Count agg support multiple expressions (#5908)
524a3c534a is described below

commit 524a3c534a9307e9d29ba5bf8aed2d2a70ae68e6
Author: allenma <[email protected]>
AuthorDate: Mon Apr 10 13:57:10 2023 +0800

    Count agg support multiple expressions (#5908)
    
    * Count agg support multiple expressions
    
    * Address review comments
---
 datafusion/core/tests/sql/aggregates.rs            | 43 +++++++++++
 datafusion/core/tests/sql/errors.rs                |  2 +-
 datafusion/expr/src/aggregate_function.rs          |  4 +-
 datafusion/expr/src/signature.rs                   |  9 +++
 datafusion/expr/src/type_coercion/aggregates.rs    |  8 ++
 datafusion/expr/src/type_coercion/functions.rs     |  3 +
 datafusion/physical-expr/src/aggregate/build_in.rs | 12 +--
 datafusion/physical-expr/src/aggregate/count.rs    | 85 ++++++++++++++++++++--
 8 files changed, 151 insertions(+), 15 deletions(-)

diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs
index e7324eed01..2496331a01 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -500,6 +500,49 @@ async fn count_aggregated_cube() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn count_multi_expr() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, true),
+        Field::new("c2", DataType::Int32, true),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![
+                Some(0),
+                None,
+                Some(1),
+                Some(2),
+                None,
+            ])),
+            Arc::new(Int32Array::from(vec![
+                Some(1),
+                Some(1),
+                Some(0),
+                None,
+                None,
+            ])),
+        ],
+    )?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("test", data)?;
+    let sql = "SELECT count(c1, c2) FROM test";
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    let expected = vec![
+        "+------------------------+",
+        "| COUNT(test.c1,test.c2) |",
+        "+------------------------+",
+        "| 2                      |",
+        "+------------------------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
 #[tokio::test]
 async fn simple_avg() -> Result<()> {
     let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
diff --git a/datafusion/core/tests/sql/errors.rs b/datafusion/core/tests/sql/errors.rs
index 64966c44e6..f04531417d 100644
--- a/datafusion/core/tests/sql/errors.rs
+++ b/datafusion/core/tests/sql/errors.rs
@@ -58,7 +58,7 @@ async fn test_aggregation_with_bad_arguments() -> Result<()> {
     assert_eq!(
         err,
         DataFusionError::Plan(
-            "The function Count expects 1 arguments, but 0 were provided".to_string()
+            "The function Count expects at least one argument".to_string()
         )
         .to_string()
     );
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index b7fb7d47d2..f1d5ea0092 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -165,8 +165,8 @@ pub fn return_type(
 pub fn signature(fun: &AggregateFunction) -> Signature {
     // note: the physical expression must accept the type returned by this function or the execution panics.
     match fun {
-        AggregateFunction::Count
-        | AggregateFunction::ApproxDistinct
+        AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable),
+        AggregateFunction::ApproxDistinct
         | AggregateFunction::Grouping
         | AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
         AggregateFunction::Min | AggregateFunction::Max => {
diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs
index 19909cf2fb..3535c885c3 100644
--- a/datafusion/expr/src/signature.rs
+++ b/datafusion/expr/src/signature.rs
@@ -46,6 +46,8 @@ pub enum TypeSignature {
     // A function such as `array` is `VariadicEqual`
     // The first argument decides the type used for coercion
     VariadicEqual,
+    /// arbitrary number of arguments with arbitrary types
+    VariadicAny,
     /// fixed number of arguments of an arbitrary but equal type out of a list of valid types
     // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])`
     // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])`
@@ -89,6 +91,13 @@ impl Signature {
             volatility,
         }
     }
+    /// variadic_any - Creates a variadic signature that represents an arbitrary number of arguments of any type.
+    pub fn variadic_any(volatility: Volatility) -> Self {
+        Self {
+            type_signature: TypeSignature::VariadicAny,
+            volatility,
+        }
+    }
     /// uniform - Creates a function with a fixed number of arguments of the same type, which must be from valid_types.
     pub fn uniform(
         arg_count: usize,
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index 3ad197afb6..3d4b9646dc 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -263,6 +263,14 @@ fn check_arg_count(
                 )));
             }
         }
+        TypeSignature::VariadicAny => {
+            if input_types.is_empty() {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} expects at least one argument",
+                    agg_fun
+                )));
+            }
+        }
         _ => {
             return Err(DataFusionError::Internal(format!(
                 "Aggregate functions do not support this {signature:?}"
diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs
index a038fdcc92..d86914325f 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -78,6 +78,9 @@ fn get_valid_types(
                 .map(|_| current_types[0].clone())
                 .collect()]
         }
+        TypeSignature::VariadicAny => {
+            vec![current_types.to_vec()]
+        }
         TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
         TypeSignature::Any(number) => {
             if current_types.len() != *number {
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index b3dbef7dfd..b1e03fb5d9 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -52,11 +52,13 @@ pub fn create_aggregate_expr(
     let input_phy_exprs = input_phy_exprs.to_vec();
 
     Ok(match (fun, distinct) {
-        (AggregateFunction::Count, false) => Arc::new(expressions::Count::new(
-            input_phy_exprs[0].clone(),
-            name,
-            return_type,
-        )),
+        (AggregateFunction::Count, false) => {
+            Arc::new(expressions::Count::new_with_multiple_exprs(
+                input_phy_exprs,
+                name,
+                return_type,
+            ))
+        }
         (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new(
             input_phy_types[0].clone(),
             input_phy_exprs[0].clone(),
diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs
index dc77b794a2..03a5c60a94 100644
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ b/datafusion/physical-expr/src/aggregate/count.rs
@@ -19,14 +19,16 @@
 
 use std::any::Any;
 use std::fmt::Debug;
+use std::ops::BitAnd;
 use std::sync::Arc;
 
 use crate::aggregate::row_accumulator::RowAccumulator;
 use crate::{AggregateExpr, PhysicalExpr};
-use arrow::array::Int64Array;
+use arrow::array::{Array, Int64Array};
 use arrow::compute;
 use arrow::datatypes::DataType;
 use arrow::{array::ArrayRef, datatypes::Field};
+use arrow_buffer::BooleanBuffer;
 use datafusion_common::{downcast_value, ScalarValue};
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::Accumulator;
@@ -41,7 +43,7 @@ pub struct Count {
     name: String,
     data_type: DataType,
     nullable: bool,
-    expr: Arc<dyn PhysicalExpr>,
+    exprs: Vec<Arc<dyn PhysicalExpr>>,
 }
 
 impl Count {
@@ -53,11 +55,43 @@ impl Count {
     ) -> Self {
         Self {
             name: name.into(),
-            expr,
+            exprs: vec![expr],
             data_type,
             nullable: true,
         }
     }
+
+    pub fn new_with_multiple_exprs(
+        exprs: Vec<Arc<dyn PhysicalExpr>>,
+        name: impl Into<String>,
+        data_type: DataType,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            exprs,
+            data_type,
+            nullable: true,
+        }
+    }
+}
+
+/// count null values for multiple columns
+/// for each row if one column value is null, then null_count + 1
+fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
+    if values.len() > 1 {
+        let result_bool_buf: Option<BooleanBuffer> = values
+            .iter()
+            .map(|a| a.data().nulls())
+            .fold(None, |acc, b| match (acc, b) {
+                (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
+                (Some(acc), None) => Some(acc),
+                (None, Some(b)) => Some(b.inner().clone()),
+                _ => None,
+            });
+        result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
+    } else {
+        values[0].null_count()
+    }
 }
 
 impl AggregateExpr for Count {
@@ -83,7 +117,7 @@ impl AggregateExpr for Count {
     }
 
     fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        vec![self.expr.clone()]
+        self.exprs.clone()
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
@@ -137,13 +171,13 @@ impl Accumulator for CountAccumulator {
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         let array = &values[0];
-        self.count += (array.len() - array.null_count()) as i64;
+        self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
         Ok(())
     }
 
     fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         let array = &values[0];
-        self.count -= (array.len() - array.null_count()) as i64;
+        self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
         Ok(())
     }
 
@@ -183,7 +217,7 @@ impl RowAccumulator for CountRowAccumulator {
         accessor: &mut RowAccessor,
     ) -> Result<()> {
         let array = &values[0];
-        let delta = (array.len() - array.null_count()) as u64;
+        let delta = (array.len() - null_count_for_multiple_cols(values)) as u64;
         accessor.add_u64(self.state_index, delta);
         Ok(())
     }
@@ -270,4 +304,41 @@ mod tests {
             Arc::new(LargeStringArray::from(vec!["a", "bb", "ccc", "dddd", "ad"]));
         generic_test_op!(a, DataType::LargeUtf8, Count, ScalarValue::from(5i64))
     }
+
+    #[test]
+    fn count_multi_cols() -> Result<()> {
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![
+            Some(1),
+            Some(2),
+            None,
+            None,
+            Some(3),
+            None,
+        ]));
+        let b: ArrayRef = Arc::new(Int32Array::from(vec![
+            Some(1),
+            None,
+            Some(2),
+            None,
+            Some(3),
+            Some(4),
+        ]));
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+        ]);
+
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?;
+
+        let agg = Arc::new(Count::new_with_multiple_exprs(
+            vec![col("a", &schema)?, col("b", &schema)?],
+            "bla".to_string(),
+            DataType::Int64,
+        ));
+        let actual = aggregate(&batch, agg)?;
+        let expected = ScalarValue::from(2i64);
+
+        assert_eq!(expected, actual);
+        Ok(())
+    }
 }

Reply via email to