This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d11820aa2 refactor count_distinct to not to have update and merge (#5408)
d11820aa2 is described below

commit d11820aa256284f9a817c7b699a548f9c3e1c399
Author: Alex Huang <[email protected]>
AuthorDate: Fri Mar 3 09:07:09 2023 +0100

    refactor count_distinct to not to have update and merge (#5408)
---
 datafusion/physical-expr/src/aggregate/build_in.rs |   5 +-
 .../physical-expr/src/aggregate/count_distinct.rs  | 430 ++++++---------------
 datafusion/proto/src/physical_plan/mod.rs          |   5 +-
 3 files changed, 117 insertions(+), 323 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index 5e609f125..b3dbef7df 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -58,10 +58,9 @@ pub fn create_aggregate_expr(
             return_type,
         )),
         (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new(
-            input_phy_types,
-            input_phy_exprs,
+            input_phy_types[0].clone(),
+            input_phy_exprs[0].clone(),
             name,
-            return_type,
         )),
         (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new(
             input_phy_exprs[0].clone(),
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index cdfc4f46a..df4a9ab7b 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -30,37 +30,30 @@ use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::Accumulator;
 
-#[derive(Debug, PartialEq, Eq, Hash, Clone)]
-struct DistinctScalarValues(Vec<ScalarValue>);
+type DistinctScalarValues = ScalarValue;
 
 /// Expression for a COUNT(DISTINCT) aggregation.
 #[derive(Debug)]
 pub struct DistinctCount {
     /// Column name
     name: String,
-    /// The DataType for the final count
-    data_type: DataType,
     /// The DataType used to hold the state for each input
-    state_data_types: Vec<DataType>,
+    state_data_type: DataType,
     /// The input arguments
-    exprs: Vec<Arc<dyn PhysicalExpr>>,
+    expr: Arc<dyn PhysicalExpr>,
 }
 
 impl DistinctCount {
     /// Create a new COUNT(DISTINCT) aggregate function.
     pub fn new(
-        input_data_types: Vec<DataType>,
-        exprs: Vec<Arc<dyn PhysicalExpr>>,
+        input_data_type: DataType,
+        expr: Arc<dyn PhysicalExpr>,
         name: String,
-        data_type: DataType,
     ) -> Self {
-        let state_data_types = input_data_types;
-
         Self {
             name,
-            data_type,
-            state_data_types,
-            exprs,
+            state_data_type: input_data_type,
+            expr,
         }
     }
 }
@@ -72,36 +65,29 @@ impl AggregateExpr for DistinctCount {
     }
 
     fn field(&self) -> Result<Field> {
-        Ok(Field::new(&self.name, self.data_type.clone(), true))
+        Ok(Field::new(&self.name, DataType::Int64, true))
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
-        Ok(self
-            .state_data_types
-            .iter()
-            .map(|state_data_type| {
-                Field::new(
-                    format_state_name(&self.name, "count distinct"),
-                    DataType::List(Box::new(Field::new(
-                        "item",
-                        state_data_type.clone(),
-                        true,
-                    ))),
-                    false,
-                )
-            })
-            .collect::<Vec<_>>())
+        Ok(vec![Field::new(
+            format_state_name(&self.name, "count distinct"),
+            DataType::List(Box::new(Field::new(
+                "item",
+                self.state_data_type.clone(),
+                true,
+            ))),
+            false,
+        )])
     }
 
     fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        self.exprs.clone()
+        vec![self.expr.clone()]
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
         Ok(Box::new(DistinctCountAccumulator {
             values: HashSet::default(),
-            state_data_types: self.state_data_types.clone(),
-            count_data_type: self.data_type.clone(),
+            state_data_type: self.state_data_type.clone(),
         }))
     }
 
@@ -113,43 +99,10 @@ impl AggregateExpr for DistinctCount {
 #[derive(Debug)]
 struct DistinctCountAccumulator {
     values: HashSet<DistinctScalarValues, RandomState>,
-    state_data_types: Vec<DataType>,
-    count_data_type: DataType,
+    state_data_type: DataType,
 }
-impl DistinctCountAccumulator {
-    fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
-        // If a row has a NULL, it is not included in the final count.
-        if !values.iter().any(|v| v.is_null()) {
-            self.values.insert(DistinctScalarValues(values.to_vec()));
-        }
-
-        Ok(())
-    }
-
-    fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
-        if states.is_empty() {
-            return Ok(());
-        }
-
-        let col_values = states
-            .iter()
-            .map(|state| match state {
-                ScalarValue::List(Some(values), _) => Ok(values),
-                _ => Err(DataFusionError::Internal(format!(
-                    "Unexpected accumulator state {state:?}"
-                ))),
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        (0..col_values[0].len()).try_for_each(|row_index| {
-            let row_values = col_values
-                .iter()
-                .map(|col| col[row_index].clone())
-                .collect::<Vec<_>>();
-            self.update(&row_values)
-        })
-    }
 
+impl DistinctCountAccumulator {
     // calculating the size for fixed length values, taking first batch size * number of batches
     // This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types
     fn fixed_size(&self) -> usize {
@@ -159,118 +112,80 @@ impl DistinctCountAccumulator {
                 .values
                 .iter()
                 .next()
-                .map(|vals| {
-                    (ScalarValue::size_of_vec(&vals.0) - std::mem::size_of_val(&vals.0))
-                        * self.values.capacity()
-                })
+                .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(&vals))
                 .unwrap_or(0)
     }
-
-    // calculates the size as accurate as possible, call to this method is expensive
-    fn full_size(&self) -> usize {
-        std::mem::size_of_val(self)
-            + (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
-            + self
-                .values
-                .iter()
-                .map(|vals| {
-                    ScalarValue::size_of_vec(&vals.0) - std::mem::size_of_val(&vals.0)
-                })
-                .sum::<usize>()
-            + (std::mem::size_of::<DataType>() * self.state_data_types.capacity())
-            + self
-                .state_data_types
-                .iter()
-                .map(|dt| dt.size() - std::mem::size_of_val(dt))
-                .sum::<usize>()
-            + self.count_data_type.size()
-            - std::mem::size_of_val(&self.count_data_type)
-    }
 }
 
 impl Accumulator for DistinctCountAccumulator {
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        let mut cols_out =
+            ScalarValue::new_list(Some(Vec::new()), self.state_data_type.clone());
+        self.values
+            .iter()
+            .enumerate()
+            .for_each(|(_, distinct_values)| {
+                if let ScalarValue::List(Some(ref mut v), _) = cols_out {
+                    v.push(distinct_values.clone());
+                }
+            });
+        Ok(vec![cols_out])
+    }
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         if values.is_empty() {
             return Ok(());
         }
-        (0..values[0].len()).try_for_each(|index| {
-            let v = values
-                .iter()
-                .map(|array| ScalarValue::try_from_array(array, index))
-                .collect::<Result<Vec<_>>>()?;
-            self.update(&v)
+        let arr = &values[0];
+        (0..arr.len()).try_for_each(|index| {
+            if !arr.is_null(index) {
+                let scalar = ScalarValue::try_from_array(arr, index)?;
+                self.values.insert(scalar);
+            }
+            Ok(())
         })
     }
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
         if states.is_empty() {
             return Ok(());
         }
-        (0..states[0].len()).try_for_each(|index| {
-            let v = states
-                .iter()
-                .map(|array| ScalarValue::try_from_array(array, index))
-                .collect::<Result<Vec<_>>>()?;
-            self.merge(&v)
+        let arr = &states[0];
+        (0..arr.len()).try_for_each(|index| {
+            let scalar = ScalarValue::try_from_array(arr, index)?;
+
+            if let ScalarValue::List(Some(scalar), _) = scalar {
+                scalar.iter().for_each(|scalar| {
+                    if !ScalarValue::is_null(scalar) {
+                        self.values.insert(scalar.clone());
+                    }
+                });
+            } else {
+                return Err(DataFusionError::Internal(
+                    "Unexpected accumulator state".into(),
+                ));
+            }
+            Ok(())
         })
     }
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        let mut cols_out = self
-            .state_data_types
-            .iter()
-            .map(|state_data_type| {
-                ScalarValue::new_list(Some(Vec::new()), state_data_type.clone())
-            })
-            .collect::<Vec<_>>();
-
-        let mut cols_vec = cols_out
-            .iter_mut()
-            .map(|c| match c {
-                ScalarValue::List(Some(ref mut v), _) => Ok(v),
-                t => Err(DataFusionError::Internal(format!(
-                    "cols_out should only consist of ScalarValue::List. {t:?} is found"
-                ))),
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        self.values.iter().for_each(|distinct_values| {
-            distinct_values.0.iter().enumerate().for_each(
-                |(col_index, distinct_value)| {
-                    cols_vec[col_index].push(distinct_value.clone());
-                },
-            )
-        });
-
-        Ok(cols_out.into_iter().collect())
-    }
 
     fn evaluate(&self) -> Result<ScalarValue> {
-        match &self.count_data_type {
-            DataType::Int64 => Ok(ScalarValue::Int64(Some(self.values.len() as i64))),
-            t => Err(DataFusionError::Internal(format!(
-                "Invalid data type {t:?} for count distinct aggregation"
-            ))),
-        }
+        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
     }
 
     fn size(&self) -> usize {
-        if self.count_data_type.is_primitive() {
-            self.fixed_size()
-        } else {
-            self.full_size()
-        }
+        self.fixed_size()
     }
 }
 
 #[cfg(test)]
 mod tests {
+    use crate::expressions::NoOp;
+
     use super::*;
     use arrow::array::{
         ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
         Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
     };
-    use arrow::array::{Int32Builder, ListBuilder, UInt64Builder};
     use arrow::datatypes::DataType;
-    use datafusion_common::cast::as_list_array;
 
     macro_rules! state_to_vec {
         ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{
@@ -300,31 +215,6 @@ mod tests {
         }};
     }
 
-    macro_rules! build_list {
-        ($LISTS:expr, $BUILDER_TYPE:ident) => {{
-            let mut builder = ListBuilder::new($BUILDER_TYPE::with_capacity(0));
-            for list in $LISTS.iter() {
-                match list {
-                    Some(values) => {
-                        for value in values.iter() {
-                            match value {
-                                Some(v) => builder.values().append_value((*v).into()),
-                                None => builder.values().append_null(),
-                            }
-                        }
-
-                        builder.append(true);
-                    }
-                    None => {
-                        builder.append(false);
-                    }
-                }
-            }
-
-            Arc::new(builder.finish()) as ArrayRef
-        }};
-    }
-
     macro_rules! test_count_distinct_update_batch_numeric {
         ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
             let values: Vec<Option<$PRIM_TYPE>> = vec![
@@ -355,28 +245,11 @@ mod tests {
         }};
     }
 
-    fn collect_states<T: Ord + Clone, S: Ord + Clone>(
-        state1: &[Option<T>],
-        state2: &[Option<S>],
-    ) -> Vec<(Option<T>, Option<S>)> {
-        let mut states = state1
-            .iter()
-            .zip(state2.iter())
-            .map(|(l, r)| (l.clone(), r.clone()))
-            .collect::<Vec<(Option<T>, Option<S>)>>();
-        states.sort();
-        states
-    }
-
     fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>, ScalarValue)> {
         let agg = DistinctCount::new(
-            arrays
-                .iter()
-                .map(|a| a.data_type().clone())
-                .collect::<Vec<_>>(),
-            vec![],
+            arrays[0].data_type().clone(),
+            Arc::new(NoOp::new()),
             String::from("__col_name__"),
-            DataType::Int64,
         );
 
         let mut accum = agg.create_accumulator()?;
@@ -390,10 +263,9 @@ mod tests {
         rows: &[Vec<ScalarValue>],
     ) -> Result<(Vec<ScalarValue>, ScalarValue)> {
         let agg = DistinctCount::new(
-            data_types.to_vec(),
-            vec![],
+            data_types[0].clone(),
+            Arc::new(NoOp::new()),
             String::from("__col_name__"),
-            DataType::Int64,
         );
 
         let mut accum = agg.create_accumulator()?;
@@ -416,24 +288,6 @@ mod tests {
         Ok((accum.state()?, accum.evaluate()?))
     }
 
-    fn run_merge_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>, ScalarValue)> {
-        let agg = DistinctCount::new(
-            arrays
-                .iter()
-                .map(|a| as_list_array(a).unwrap())
-                .map(|a| a.values().data_type().clone())
-                .collect::<Vec<_>>(),
-            vec![],
-            String::from("__col_name__"),
-            DataType::Int64,
-        );
-
-        let mut accum = agg.create_accumulator()?;
-        accum.merge_batch(arrays)?;
-
-        Ok((accum.state()?, accum.evaluate()?))
-    }
-
     // Used trait to create associated constant for f32 and f64
     trait SubNormal: 'static {
         const SUBNORMAL: Self;
@@ -635,133 +489,75 @@ mod tests {
         Ok(())
     }
 
-    #[test]
-    fn count_distinct_update_batch_multiple_columns() -> Result<()> {
-        let array_int8: ArrayRef = Arc::new(Int8Array::from(vec![1, 1, 2]));
-        let array_int16: ArrayRef = Arc::new(Int16Array::from(vec![3, 3, 4]));
-        let arrays = vec![array_int8, array_int16];
-
-        let (states, result) = run_update_batch(&arrays)?;
-
-        let state_vec1 = state_to_vec!(&states[0], Int8, i8).unwrap();
-        let state_vec2 = state_to_vec!(&states[1], Int16, i16).unwrap();
-        let state_pairs = collect_states::<i8, i16>(&state_vec1, &state_vec2);
-
-        assert_eq!(states.len(), 2);
-        assert_eq!(
-            state_pairs,
-            vec![(Some(1_i8), Some(3_i16)), (Some(2_i8), Some(4_i16))]
-        );
-
-        assert_eq!(result, ScalarValue::Int64(Some(2)));
-
-        Ok(())
-    }
-
     #[test]
     fn count_distinct_update() -> Result<()> {
         let (states, result) = run_update(
-            &[DataType::Int32, DataType::UInt64],
+            &[DataType::Int32],
             &[
-                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))],
-                vec![ScalarValue::Int32(Some(5)), ScalarValue::UInt64(Some(1))],
-                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))],
-                vec![ScalarValue::Int32(Some(5)), ScalarValue::UInt64(Some(1))],
-                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(6))],
-                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(7))],
-                vec![ScalarValue::Int32(Some(2)), ScalarValue::UInt64(Some(7))],
+                vec![ScalarValue::Int32(Some(-1))],
+                vec![ScalarValue::Int32(Some(5))],
+                vec![ScalarValue::Int32(Some(-1))],
+                vec![ScalarValue::Int32(Some(5))],
+                vec![ScalarValue::Int32(Some(-1))],
+                vec![ScalarValue::Int32(Some(-1))],
+                vec![ScalarValue::Int32(Some(2))],
             ],
         )?;
+        assert_eq!(states.len(), 1);
+        assert_eq!(result, ScalarValue::Int64(Some(3)));
 
-        let state_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
-        let state_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
-        let state_pairs = collect_states::<i32, u64>(&state_vec1, &state_vec2);
-
-        assert_eq!(states.len(), 2);
-        assert_eq!(
-            state_pairs,
-            vec![
-                (Some(-1_i32), Some(5_u64)),
-                (Some(-1_i32), Some(6_u64)),
-                (Some(-1_i32), Some(7_u64)),
-                (Some(2_i32), Some(7_u64)),
-                (Some(5_i32), Some(1_u64)),
-            ]
-        );
-        assert_eq!(result, ScalarValue::Int64(Some(5)));
-
+        let (states, result) = run_update(
+            &[DataType::UInt64],
+            &[
+                vec![ScalarValue::UInt64(Some(1))],
+                vec![ScalarValue::UInt64(Some(5))],
+                vec![ScalarValue::UInt64(Some(1))],
+                vec![ScalarValue::UInt64(Some(5))],
+                vec![ScalarValue::UInt64(Some(1))],
+                vec![ScalarValue::UInt64(Some(1))],
+                vec![ScalarValue::UInt64(Some(2))],
+            ],
+        )?;
+        assert_eq!(states.len(), 1);
+        assert_eq!(result, ScalarValue::Int64(Some(3)));
         Ok(())
     }
 
     #[test]
     fn count_distinct_update_with_nulls() -> Result<()> {
         let (states, result) = run_update(
-            &[DataType::Int32, DataType::UInt64],
+            &[DataType::Int32],
             &[
                 // None of these updates contains a None, so these are accumulated.
-                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))],
-                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))],
-                vec![ScalarValue::Int32(Some(-2)), ScalarValue::UInt64(Some(5))],
+                vec![ScalarValue::Int32(Some(-1))],
+                vec![ScalarValue::Int32(Some(-1))],
+                vec![ScalarValue::Int32(Some(-2))],
                 // Each of these updates contains at least one None, so these
                 // won't be accumulated.
-                vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(None)],
-                vec![ScalarValue::Int32(None), ScalarValue::UInt64(Some(5))],
-                vec![ScalarValue::Int32(None), ScalarValue::UInt64(None)],
+                vec![ScalarValue::Int32(Some(-1))],
+                vec![ScalarValue::Int32(None)],
+                vec![ScalarValue::Int32(None)],
             ],
         )?;
-
-        let state_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
-        let state_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
-        let state_pairs = collect_states::<i32, u64>(&state_vec1, &state_vec2);
-
-        assert_eq!(states.len(), 2);
-        assert_eq!(
-            state_pairs,
-            vec![(Some(-2_i32), Some(5_u64)), (Some(-1_i32), Some(5_u64))]
-        );
-
+        assert_eq!(states.len(), 1);
         assert_eq!(result, ScalarValue::Int64(Some(2)));
 
-        Ok(())
-    }
-
-    #[test]
-    fn count_distinct_merge_batch() -> Result<()> {
-        let state_in1 = build_list!(
-            vec![
-                Some(vec![Some(-1_i32), Some(-1_i32), Some(-2_i32), Some(-2_i32)]),
-                Some(vec![Some(-2_i32), Some(-3_i32)]),
-            ],
-            Int32Builder
-        );
-
-        let state_in2 = build_list!(
-            vec![
-                Some(vec![Some(5_u64), Some(6_u64), Some(5_u64), Some(7_u64)]),
-                Some(vec![Some(5_u64), Some(7_u64)]),
+        let (states, result) = run_update(
+            &[DataType::UInt64],
+            &[
+                // None of these updates contains a None, so these are accumulated.
+                vec![ScalarValue::UInt64(Some(1))],
+                vec![ScalarValue::UInt64(Some(1))],
+                vec![ScalarValue::UInt64(Some(2))],
+                // Each of these updates contains at least one None, so these
+                // won't be accumulated.
+                vec![ScalarValue::UInt64(Some(1))],
+                vec![ScalarValue::UInt64(None)],
+                vec![ScalarValue::UInt64(None)],
             ],
-            UInt64Builder
-        );
-
-        let (states, result) = run_merge_batch(&[state_in1, state_in2])?;
-
-        let state_out_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
-        let state_out_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
-        let state_pairs = collect_states::<i32, u64>(&state_out_vec1, &state_out_vec2);
-
-        assert_eq!(
-            state_pairs,
-            vec![
-                (Some(-3_i32), Some(7_u64)),
-                (Some(-2_i32), Some(5_u64)),
-                (Some(-2_i32), Some(7_u64)),
-                (Some(-1_i32), Some(5_u64)),
-                (Some(-1_i32), Some(6_u64)),
-            ]
-        );
-
-        assert_eq!(result, ScalarValue::Int64(Some(5)));
-
+        )?;
+        assert_eq!(states.len(), 1);
+        assert_eq!(result, ScalarValue::Int64(Some(2)));
         Ok(())
     }
 }
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 8c2ce822f..b9b23a3ce 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1596,10 +1596,9 @@ mod roundtrip_tests {
         let schema = Arc::new(Schema::new(vec![field_a, field_b]));
 
         let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(DistinctCount::new(
-            vec![DataType::Int64],
-            vec![col("b", &schema)?],
-            "COUNT(DISTINCT b)".to_string(),
             DataType::Int64,
+            col("b", &schema)?,
+            "COUNT(DISTINCT b)".to_string(),
         ))];
 
         let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =

Reply via email to