This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d11820aa2 refactor count_distinct to not have update and merge
(#5408)
d11820aa2 is described below
commit d11820aa256284f9a817c7b699a548f9c3e1c399
Author: Alex Huang <[email protected]>
AuthorDate: Fri Mar 3 09:07:09 2023 +0100
refactor count_distinct to not have update and merge (#5408)
---
datafusion/physical-expr/src/aggregate/build_in.rs | 5 +-
.../physical-expr/src/aggregate/count_distinct.rs | 430 ++++++---------------
datafusion/proto/src/physical_plan/mod.rs | 5 +-
3 files changed, 117 insertions(+), 323 deletions(-)
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs
b/datafusion/physical-expr/src/aggregate/build_in.rs
index 5e609f125..b3dbef7df 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -58,10 +58,9 @@ pub fn create_aggregate_expr(
return_type,
)),
(AggregateFunction::Count, true) =>
Arc::new(expressions::DistinctCount::new(
- input_phy_types,
- input_phy_exprs,
+ input_phy_types[0].clone(),
+ input_phy_exprs[0].clone(),
name,
- return_type,
)),
(AggregateFunction::Grouping, _) =>
Arc::new(expressions::Grouping::new(
input_phy_exprs[0].clone(),
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs
b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index cdfc4f46a..df4a9ab7b 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -30,37 +30,30 @@ use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
-#[derive(Debug, PartialEq, Eq, Hash, Clone)]
-struct DistinctScalarValues(Vec<ScalarValue>);
+type DistinctScalarValues = ScalarValue;
/// Expression for a COUNT(DISTINCT) aggregation.
#[derive(Debug)]
pub struct DistinctCount {
/// Column name
name: String,
- /// The DataType for the final count
- data_type: DataType,
/// The DataType used to hold the state for each input
- state_data_types: Vec<DataType>,
+ state_data_type: DataType,
/// The input arguments
- exprs: Vec<Arc<dyn PhysicalExpr>>,
+ expr: Arc<dyn PhysicalExpr>,
}
impl DistinctCount {
/// Create a new COUNT(DISTINCT) aggregate function.
pub fn new(
- input_data_types: Vec<DataType>,
- exprs: Vec<Arc<dyn PhysicalExpr>>,
+ input_data_type: DataType,
+ expr: Arc<dyn PhysicalExpr>,
name: String,
- data_type: DataType,
) -> Self {
- let state_data_types = input_data_types;
-
Self {
name,
- data_type,
- state_data_types,
- exprs,
+ state_data_type: input_data_type,
+ expr,
}
}
}
@@ -72,36 +65,29 @@ impl AggregateExpr for DistinctCount {
}
fn field(&self) -> Result<Field> {
- Ok(Field::new(&self.name, self.data_type.clone(), true))
+ Ok(Field::new(&self.name, DataType::Int64, true))
}
fn state_fields(&self) -> Result<Vec<Field>> {
- Ok(self
- .state_data_types
- .iter()
- .map(|state_data_type| {
- Field::new(
- format_state_name(&self.name, "count distinct"),
- DataType::List(Box::new(Field::new(
- "item",
- state_data_type.clone(),
- true,
- ))),
- false,
- )
- })
- .collect::<Vec<_>>())
+ Ok(vec![Field::new(
+ format_state_name(&self.name, "count distinct"),
+ DataType::List(Box::new(Field::new(
+ "item",
+ self.state_data_type.clone(),
+ true,
+ ))),
+ false,
+ )])
}
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- self.exprs.clone()
+ vec![self.expr.clone()]
}
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(DistinctCountAccumulator {
values: HashSet::default(),
- state_data_types: self.state_data_types.clone(),
- count_data_type: self.data_type.clone(),
+ state_data_type: self.state_data_type.clone(),
}))
}
@@ -113,43 +99,10 @@ impl AggregateExpr for DistinctCount {
#[derive(Debug)]
struct DistinctCountAccumulator {
values: HashSet<DistinctScalarValues, RandomState>,
- state_data_types: Vec<DataType>,
- count_data_type: DataType,
+ state_data_type: DataType,
}
-impl DistinctCountAccumulator {
- fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
- // If a row has a NULL, it is not included in the final count.
- if !values.iter().any(|v| v.is_null()) {
- self.values.insert(DistinctScalarValues(values.to_vec()));
- }
-
- Ok(())
- }
-
- fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
- if states.is_empty() {
- return Ok(());
- }
-
- let col_values = states
- .iter()
- .map(|state| match state {
- ScalarValue::List(Some(values), _) => Ok(values),
- _ => Err(DataFusionError::Internal(format!(
- "Unexpected accumulator state {state:?}"
- ))),
- })
- .collect::<Result<Vec<_>>>()?;
-
- (0..col_values[0].len()).try_for_each(|row_index| {
- let row_values = col_values
- .iter()
- .map(|col| col[row_index].clone())
- .collect::<Vec<_>>();
- self.update(&row_values)
- })
- }
+impl DistinctCountAccumulator {
// calculating the size for fixed length values, taking first batch size *
number of batches
// This method is faster than .full_size(), however it is not suitable for
variable length values like strings or complex types
fn fixed_size(&self) -> usize {
@@ -159,118 +112,80 @@ impl DistinctCountAccumulator {
.values
.iter()
.next()
- .map(|vals| {
- (ScalarValue::size_of_vec(&vals.0) -
std::mem::size_of_val(&vals.0))
- * self.values.capacity()
- })
+ .map(|vals| ScalarValue::size(vals) -
std::mem::size_of_val(&vals))
.unwrap_or(0)
}
-
- // calculates the size as accurate as possible, call to this method is
expensive
- fn full_size(&self) -> usize {
- std::mem::size_of_val(self)
- + (std::mem::size_of::<DistinctScalarValues>() *
self.values.capacity())
- + self
- .values
- .iter()
- .map(|vals| {
- ScalarValue::size_of_vec(&vals.0) -
std::mem::size_of_val(&vals.0)
- })
- .sum::<usize>()
- + (std::mem::size_of::<DataType>() *
self.state_data_types.capacity())
- + self
- .state_data_types
- .iter()
- .map(|dt| dt.size() - std::mem::size_of_val(dt))
- .sum::<usize>()
- + self.count_data_type.size()
- - std::mem::size_of_val(&self.count_data_type)
- }
}
impl Accumulator for DistinctCountAccumulator {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ let mut cols_out =
+ ScalarValue::new_list(Some(Vec::new()),
self.state_data_type.clone());
+ self.values
+ .iter()
+ .enumerate()
+ .for_each(|(_, distinct_values)| {
+ if let ScalarValue::List(Some(ref mut v), _) = cols_out {
+ v.push(distinct_values.clone());
+ }
+ });
+ Ok(vec![cols_out])
+ }
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
- (0..values[0].len()).try_for_each(|index| {
- let v = values
- .iter()
- .map(|array| ScalarValue::try_from_array(array, index))
- .collect::<Result<Vec<_>>>()?;
- self.update(&v)
+ let arr = &values[0];
+ (0..arr.len()).try_for_each(|index| {
+ if !arr.is_null(index) {
+ let scalar = ScalarValue::try_from_array(arr, index)?;
+ self.values.insert(scalar);
+ }
+ Ok(())
})
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
- (0..states[0].len()).try_for_each(|index| {
- let v = states
- .iter()
- .map(|array| ScalarValue::try_from_array(array, index))
- .collect::<Result<Vec<_>>>()?;
- self.merge(&v)
+ let arr = &states[0];
+ (0..arr.len()).try_for_each(|index| {
+ let scalar = ScalarValue::try_from_array(arr, index)?;
+
+ if let ScalarValue::List(Some(scalar), _) = scalar {
+ scalar.iter().for_each(|scalar| {
+ if !ScalarValue::is_null(scalar) {
+ self.values.insert(scalar.clone());
+ }
+ });
+ } else {
+ return Err(DataFusionError::Internal(
+ "Unexpected accumulator state".into(),
+ ));
+ }
+ Ok(())
})
}
- fn state(&self) -> Result<Vec<ScalarValue>> {
- let mut cols_out = self
- .state_data_types
- .iter()
- .map(|state_data_type| {
- ScalarValue::new_list(Some(Vec::new()),
state_data_type.clone())
- })
- .collect::<Vec<_>>();
-
- let mut cols_vec = cols_out
- .iter_mut()
- .map(|c| match c {
- ScalarValue::List(Some(ref mut v), _) => Ok(v),
- t => Err(DataFusionError::Internal(format!(
- "cols_out should only consist of ScalarValue::List. {t:?}
is found"
- ))),
- })
- .collect::<Result<Vec<_>>>()?;
-
- self.values.iter().for_each(|distinct_values| {
- distinct_values.0.iter().enumerate().for_each(
- |(col_index, distinct_value)| {
- cols_vec[col_index].push(distinct_value.clone());
- },
- )
- });
-
- Ok(cols_out.into_iter().collect())
- }
fn evaluate(&self) -> Result<ScalarValue> {
- match &self.count_data_type {
- DataType::Int64 => Ok(ScalarValue::Int64(Some(self.values.len() as
i64))),
- t => Err(DataFusionError::Internal(format!(
- "Invalid data type {t:?} for count distinct aggregation"
- ))),
- }
+ Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
}
fn size(&self) -> usize {
- if self.count_data_type.is_primitive() {
- self.fixed_size()
- } else {
- self.full_size()
- }
+ self.fixed_size()
}
}
#[cfg(test)]
mod tests {
+ use crate::expressions::NoOp;
+
use super::*;
use arrow::array::{
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array,
Int32Array,
Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array,
UInt8Array,
};
- use arrow::array::{Int32Builder, ListBuilder, UInt64Builder};
use arrow::datatypes::DataType;
- use datafusion_common::cast::as_list_array;
macro_rules! state_to_vec {
($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{
@@ -300,31 +215,6 @@ mod tests {
}};
}
- macro_rules! build_list {
- ($LISTS:expr, $BUILDER_TYPE:ident) => {{
- let mut builder =
ListBuilder::new($BUILDER_TYPE::with_capacity(0));
- for list in $LISTS.iter() {
- match list {
- Some(values) => {
- for value in values.iter() {
- match value {
- Some(v) =>
builder.values().append_value((*v).into()),
- None => builder.values().append_null(),
- }
- }
-
- builder.append(true);
- }
- None => {
- builder.append(false);
- }
- }
- }
-
- Arc::new(builder.finish()) as ArrayRef
- }};
- }
-
macro_rules! test_count_distinct_update_batch_numeric {
($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
let values: Vec<Option<$PRIM_TYPE>> = vec![
@@ -355,28 +245,11 @@ mod tests {
}};
}
- fn collect_states<T: Ord + Clone, S: Ord + Clone>(
- state1: &[Option<T>],
- state2: &[Option<S>],
- ) -> Vec<(Option<T>, Option<S>)> {
- let mut states = state1
- .iter()
- .zip(state2.iter())
- .map(|(l, r)| (l.clone(), r.clone()))
- .collect::<Vec<(Option<T>, Option<S>)>>();
- states.sort();
- states
- }
-
fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>,
ScalarValue)> {
let agg = DistinctCount::new(
- arrays
- .iter()
- .map(|a| a.data_type().clone())
- .collect::<Vec<_>>(),
- vec![],
+ arrays[0].data_type().clone(),
+ Arc::new(NoOp::new()),
String::from("__col_name__"),
- DataType::Int64,
);
let mut accum = agg.create_accumulator()?;
@@ -390,10 +263,9 @@ mod tests {
rows: &[Vec<ScalarValue>],
) -> Result<(Vec<ScalarValue>, ScalarValue)> {
let agg = DistinctCount::new(
- data_types.to_vec(),
- vec![],
+ data_types[0].clone(),
+ Arc::new(NoOp::new()),
String::from("__col_name__"),
- DataType::Int64,
);
let mut accum = agg.create_accumulator()?;
@@ -416,24 +288,6 @@ mod tests {
Ok((accum.state()?, accum.evaluate()?))
}
- fn run_merge_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>,
ScalarValue)> {
- let agg = DistinctCount::new(
- arrays
- .iter()
- .map(|a| as_list_array(a).unwrap())
- .map(|a| a.values().data_type().clone())
- .collect::<Vec<_>>(),
- vec![],
- String::from("__col_name__"),
- DataType::Int64,
- );
-
- let mut accum = agg.create_accumulator()?;
- accum.merge_batch(arrays)?;
-
- Ok((accum.state()?, accum.evaluate()?))
- }
-
// Used trait to create associated constant for f32 and f64
trait SubNormal: 'static {
const SUBNORMAL: Self;
@@ -635,133 +489,75 @@ mod tests {
Ok(())
}
- #[test]
- fn count_distinct_update_batch_multiple_columns() -> Result<()> {
- let array_int8: ArrayRef = Arc::new(Int8Array::from(vec![1, 1, 2]));
- let array_int16: ArrayRef = Arc::new(Int16Array::from(vec![3, 3, 4]));
- let arrays = vec![array_int8, array_int16];
-
- let (states, result) = run_update_batch(&arrays)?;
-
- let state_vec1 = state_to_vec!(&states[0], Int8, i8).unwrap();
- let state_vec2 = state_to_vec!(&states[1], Int16, i16).unwrap();
- let state_pairs = collect_states::<i8, i16>(&state_vec1, &state_vec2);
-
- assert_eq!(states.len(), 2);
- assert_eq!(
- state_pairs,
- vec![(Some(1_i8), Some(3_i16)), (Some(2_i8), Some(4_i16))]
- );
-
- assert_eq!(result, ScalarValue::Int64(Some(2)));
-
- Ok(())
- }
-
#[test]
fn count_distinct_update() -> Result<()> {
let (states, result) = run_update(
- &[DataType::Int32, DataType::UInt64],
+ &[DataType::Int32],
&[
- vec![ScalarValue::Int32(Some(-1)),
ScalarValue::UInt64(Some(5))],
- vec![ScalarValue::Int32(Some(5)),
ScalarValue::UInt64(Some(1))],
- vec![ScalarValue::Int32(Some(-1)),
ScalarValue::UInt64(Some(5))],
- vec![ScalarValue::Int32(Some(5)),
ScalarValue::UInt64(Some(1))],
- vec![ScalarValue::Int32(Some(-1)),
ScalarValue::UInt64(Some(6))],
- vec![ScalarValue::Int32(Some(-1)),
ScalarValue::UInt64(Some(7))],
- vec![ScalarValue::Int32(Some(2)),
ScalarValue::UInt64(Some(7))],
+ vec![ScalarValue::Int32(Some(-1))],
+ vec![ScalarValue::Int32(Some(5))],
+ vec![ScalarValue::Int32(Some(-1))],
+ vec![ScalarValue::Int32(Some(5))],
+ vec![ScalarValue::Int32(Some(-1))],
+ vec![ScalarValue::Int32(Some(-1))],
+ vec![ScalarValue::Int32(Some(2))],
],
)?;
+ assert_eq!(states.len(), 1);
+ assert_eq!(result, ScalarValue::Int64(Some(3)));
- let state_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
- let state_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
- let state_pairs = collect_states::<i32, u64>(&state_vec1, &state_vec2);
-
- assert_eq!(states.len(), 2);
- assert_eq!(
- state_pairs,
- vec![
- (Some(-1_i32), Some(5_u64)),
- (Some(-1_i32), Some(6_u64)),
- (Some(-1_i32), Some(7_u64)),
- (Some(2_i32), Some(7_u64)),
- (Some(5_i32), Some(1_u64)),
- ]
- );
- assert_eq!(result, ScalarValue::Int64(Some(5)));
-
+ let (states, result) = run_update(
+ &[DataType::UInt64],
+ &[
+ vec![ScalarValue::UInt64(Some(1))],
+ vec![ScalarValue::UInt64(Some(5))],
+ vec![ScalarValue::UInt64(Some(1))],
+ vec![ScalarValue::UInt64(Some(5))],
+ vec![ScalarValue::UInt64(Some(1))],
+ vec![ScalarValue::UInt64(Some(1))],
+ vec![ScalarValue::UInt64(Some(2))],
+ ],
+ )?;
+ assert_eq!(states.len(), 1);
+ assert_eq!(result, ScalarValue::Int64(Some(3)));
Ok(())
}
#[test]
fn count_distinct_update_with_nulls() -> Result<()> {
let (states, result) = run_update(
- &[DataType::Int32, DataType::UInt64],
+ &[DataType::Int32],
&[
// None of these updates contains a None, so these are
accumulated.
- vec![ScalarValue::Int32(Some(-1)),
ScalarValue::UInt64(Some(5))],
- vec![ScalarValue::Int32(Some(-1)),
ScalarValue::UInt64(Some(5))],
- vec![ScalarValue::Int32(Some(-2)),
ScalarValue::UInt64(Some(5))],
+ vec![ScalarValue::Int32(Some(-1))],
+ vec![ScalarValue::Int32(Some(-1))],
+ vec![ScalarValue::Int32(Some(-2))],
// Each of these updates contains at least one None, so these
// won't be accumulated.
- vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(None)],
- vec![ScalarValue::Int32(None), ScalarValue::UInt64(Some(5))],
- vec![ScalarValue::Int32(None), ScalarValue::UInt64(None)],
+ vec![ScalarValue::Int32(Some(-1))],
+ vec![ScalarValue::Int32(None)],
+ vec![ScalarValue::Int32(None)],
],
)?;
-
- let state_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
- let state_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
- let state_pairs = collect_states::<i32, u64>(&state_vec1, &state_vec2);
-
- assert_eq!(states.len(), 2);
- assert_eq!(
- state_pairs,
- vec![(Some(-2_i32), Some(5_u64)), (Some(-1_i32), Some(5_u64))]
- );
-
+ assert_eq!(states.len(), 1);
assert_eq!(result, ScalarValue::Int64(Some(2)));
- Ok(())
- }
-
- #[test]
- fn count_distinct_merge_batch() -> Result<()> {
- let state_in1 = build_list!(
- vec![
- Some(vec![Some(-1_i32), Some(-1_i32), Some(-2_i32),
Some(-2_i32)]),
- Some(vec![Some(-2_i32), Some(-3_i32)]),
- ],
- Int32Builder
- );
-
- let state_in2 = build_list!(
- vec![
- Some(vec![Some(5_u64), Some(6_u64), Some(5_u64), Some(7_u64)]),
- Some(vec![Some(5_u64), Some(7_u64)]),
+ let (states, result) = run_update(
+ &[DataType::UInt64],
+ &[
+ // None of these updates contains a None, so these are
accumulated.
+ vec![ScalarValue::UInt64(Some(1))],
+ vec![ScalarValue::UInt64(Some(1))],
+ vec![ScalarValue::UInt64(Some(2))],
+ // Each of these updates contains at least one None, so these
+ // won't be accumulated.
+ vec![ScalarValue::UInt64(Some(1))],
+ vec![ScalarValue::UInt64(None)],
+ vec![ScalarValue::UInt64(None)],
],
- UInt64Builder
- );
-
- let (states, result) = run_merge_batch(&[state_in1, state_in2])?;
-
- let state_out_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap();
- let state_out_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap();
- let state_pairs = collect_states::<i32, u64>(&state_out_vec1,
&state_out_vec2);
-
- assert_eq!(
- state_pairs,
- vec![
- (Some(-3_i32), Some(7_u64)),
- (Some(-2_i32), Some(5_u64)),
- (Some(-2_i32), Some(7_u64)),
- (Some(-1_i32), Some(5_u64)),
- (Some(-1_i32), Some(6_u64)),
- ]
- );
-
- assert_eq!(result, ScalarValue::Int64(Some(5)));
-
+ )?;
+ assert_eq!(states.len(), 1);
+ assert_eq!(result, ScalarValue::Int64(Some(2)));
Ok(())
}
}
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index 8c2ce822f..b9b23a3ce 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1596,10 +1596,9 @@ mod roundtrip_tests {
let schema = Arc::new(Schema::new(vec![field_a, field_b]));
let aggregates: Vec<Arc<dyn AggregateExpr>> =
vec![Arc::new(DistinctCount::new(
- vec![DataType::Int64],
- vec![col("b", &schema)?],
- "COUNT(DISTINCT b)".to_string(),
DataType::Int64,
+ col("b", &schema)?,
+ "COUNT(DISTINCT b)".to_string(),
))];
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =