This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 561d941854 feat: native types in `DistinctCountAccumulator` for primitive types (#8721)
561d941854 is described below
commit 561d941854e96bed2de3c7a4c6de9afab622a08b
Author: Eduard Karacharov <[email protected]>
AuthorDate: Fri Jan 5 11:32:17 2024 +0200
feat: native types in `DistinctCountAccumulator` for primitive types (#8721)
* DistinctCountGroupsAccumulator
* test coverage
* clippy warnings
* count distinct for primitive types
* revert hashset to std
* fixed accumulator size estimation
---
.../physical-expr/src/aggregate/count_distinct.rs | 298 ++++++++++++++++++++-
.../physical-expr/src/aggregate/sum_distinct.rs | 22 +-
datafusion/physical-expr/src/aggregate/utils.rs | 20 +-
3 files changed, 311 insertions(+), 29 deletions(-)
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs
b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index c2fd32a96c..f7c13948b2 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -15,21 +15,32 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::datatypes::{DataType, Field};
+use arrow::datatypes::{DataType, Field, TimeUnit};
+use arrow_array::types::{
+ ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
+ Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
Int8Type,
+ Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
Time64NanosecondType,
+ TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType,
+ TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use arrow_array::PrimitiveArray;
use std::any::Any;
+use std::cmp::Eq;
use std::fmt::Debug;
+use std::hash::Hash;
use std::sync::Arc;
use ahash::RandomState;
use arrow::array::{Array, ArrayRef};
use std::collections::HashSet;
-use crate::aggregate::utils::down_cast_any_ref;
+use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
-use datafusion_common::Result;
-use datafusion_common::ScalarValue;
+use datafusion_common::cast::{as_list_array, as_primitive_array};
+use datafusion_common::utils::array_into_list_array;
+use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;
type DistinctScalarValues = ScalarValue;
@@ -60,6 +71,18 @@ impl DistinctCount {
}
}
+/// Expands to `Ok(Box::new(NativeDistinctCountAccumulator::<$TYPE>::new()))`:
+/// a boxed distinct-count accumulator that stores `$TYPE`'s native values
+/// directly in a hash set (the natives must be `Eq + Hash`).
+macro_rules! native_distinct_count_accumulator {
+ ($TYPE:ident) => {{
+ Ok(Box::new(NativeDistinctCountAccumulator::<$TYPE>::new()))
+ }};
+}
+
+/// Floating-point counterpart of `native_distinct_count_accumulator!`:
+/// expands to a boxed `FloatDistinctCountAccumulator::<$TYPE>`, which stores
+/// the natives through the `Hashable` wrapper because floats are not `Eq`/`Hash`.
+macro_rules! float_distinct_count_accumulator {
+ ($TYPE:ident) => {{
+ Ok(Box::new(FloatDistinctCountAccumulator::<$TYPE>::new()))
+ }};
+}
+
impl AggregateExpr for DistinctCount {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
@@ -83,10 +106,57 @@ impl AggregateExpr for DistinctCount {
}
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- Ok(Box::new(DistinctCountAccumulator {
- values: HashSet::default(),
- state_data_type: self.state_data_type.clone(),
- }))
+        use DataType::*;
+        use TimeUnit::*;
+
+        // The native accumulators serialize their state through
+        // `PrimitiveArray::<T>::from_iter_values`, so the list items they emit
+        // always have type `T::DATA_TYPE`. The `ScalarValue`-based fallback is
+        // handed `self.state_data_type` instead. The partial state has to carry
+        // the aggregate's declared state type, so a native accumulator is only
+        // chosen when `T::DATA_TYPE` equals `self.state_data_type`. Decimals
+        // with a non-default precision/scale and timestamps with a time zone
+        // fall through to the fallback.
+        // TODO: carry the data type inside the native accumulators so these
+        // parameterized types can also use the fast path.
+        match &self.state_data_type {
+            Int8 => native_distinct_count_accumulator!(Int8Type),
+            Int16 => native_distinct_count_accumulator!(Int16Type),
+            Int32 => native_distinct_count_accumulator!(Int32Type),
+            Int64 => native_distinct_count_accumulator!(Int64Type),
+            UInt8 => native_distinct_count_accumulator!(UInt8Type),
+            UInt16 => native_distinct_count_accumulator!(UInt16Type),
+            UInt32 => native_distinct_count_accumulator!(UInt32Type),
+            UInt64 => native_distinct_count_accumulator!(UInt64Type),
+            Decimal128(_, _)
+                if self.state_data_type == Decimal128Type::DATA_TYPE =>
+            {
+                native_distinct_count_accumulator!(Decimal128Type)
+            }
+            Decimal256(_, _)
+                if self.state_data_type == Decimal256Type::DATA_TYPE =>
+            {
+                native_distinct_count_accumulator!(Decimal256Type)
+            }
+
+            Date32 => native_distinct_count_accumulator!(Date32Type),
+            Date64 => native_distinct_count_accumulator!(Date64Type),
+            Time32(Millisecond) => {
+                native_distinct_count_accumulator!(Time32MillisecondType)
+            }
+            Time32(Second) => {
+                native_distinct_count_accumulator!(Time32SecondType)
+            }
+            Time64(Microsecond) => {
+                native_distinct_count_accumulator!(Time64MicrosecondType)
+            }
+            Time64(Nanosecond) => {
+                native_distinct_count_accumulator!(Time64NanosecondType)
+            }
+            // Only timezone-less timestamps match `T::DATA_TYPE`.
+            Timestamp(Microsecond, None) => {
+                native_distinct_count_accumulator!(TimestampMicrosecondType)
+            }
+            Timestamp(Millisecond, None) => {
+                native_distinct_count_accumulator!(TimestampMillisecondType)
+            }
+            Timestamp(Nanosecond, None) => {
+                native_distinct_count_accumulator!(TimestampNanosecondType)
+            }
+            Timestamp(Second, None) => {
+                native_distinct_count_accumulator!(TimestampSecondType)
+            }
+
+            Float16 => float_distinct_count_accumulator!(Float16Type),
+            Float32 => float_distinct_count_accumulator!(Float32Type),
+            Float64 => float_distinct_count_accumulator!(Float64Type),
+
+            // Everything else, including the parameterized cases excluded
+            // above, uses the generic `ScalarValue`-based accumulator, which
+            // builds its state from `self.state_data_type`.
+            _ => Ok(Box::new(DistinctCountAccumulator {
+                values: HashSet::default(),
+                state_data_type: self.state_data_type.clone(),
+            })),
+        }
}
fn name(&self) -> &str {
@@ -192,6 +262,182 @@ impl Accumulator for DistinctCountAccumulator {
}
}
+/// `COUNT(DISTINCT)` accumulator for primitive arrow types whose native
+/// representation is `Eq + Hash`. Distinct values are kept unboxed in a
+/// `HashSet<T::Native>` rather than as `ScalarValue`s.
+#[derive(Debug)]
+struct NativeDistinctCountAccumulator<T>
+where
+ T: ArrowPrimitiveType + Send,
+ T::Native: Eq + Hash,
+{
+ // Distinct non-null values seen so far.
+ values: HashSet<T::Native, RandomState>,
+}
+
+impl<T> NativeDistinctCountAccumulator<T>
+where
+ T: ArrowPrimitiveType + Send,
+ T::Native: Eq + Hash,
+{
+ /// Creates an accumulator with an empty value set.
+ fn new() -> Self {
+ Self {
+ values: HashSet::default(),
+ }
+ }
+}
+
+impl<T> Accumulator for NativeDistinctCountAccumulator<T>
+where
+ T: ArrowPrimitiveType + Send + Debug,
+ T::Native: Eq + Hash,
+{
+ /// Emits the distinct values as one `ScalarValue::List`, built by wrapping
+ /// a `PrimitiveArray<T>` with `array_into_list_array`.
+ /// NOTE(review): `from_iter_values` gives the items `T`'s default data type,
+ /// not the aggregate's declared input type. Confirm they agree for
+ /// parameterized types (decimal precision/scale, timestamp time zone).
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
+ self.values.iter().cloned(),
+ )) as ArrayRef;
+ let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
+ Ok(vec![ScalarValue::List(list)])
+ }
+
+ /// Inserts every non-null value of the first input column; nulls are
+ /// skipped. An empty `values` slice is a no-op.
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ if values.is_empty() {
+ return Ok(());
+ }
+
+ let arr = as_primitive_array::<T>(&values[0])?;
+ arr.iter().for_each(|value| {
+ if let Some(value) = value {
+ self.values.insert(value);
+ }
+ });
+
+ Ok(())
+ }
+
+ /// Merges partial states. Each row of `states[0]` is a list of distinct
+ /// values in the shape produced by `state`, and null rows are skipped.
+ /// Panics if more than one state column is passed.
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ }
+ assert_eq!(
+ states.len(),
+ 1,
+ "count_distinct states must be single array"
+ );
+
+ let arr = as_list_array(&states[0])?;
+ arr.iter().try_for_each(|maybe_list| {
+ if let Some(list) = maybe_list {
+ let list = as_primitive_array::<T>(&list)?;
+ // Reads the raw values buffer and ignores any null bitmap. This
+ // relies on `state` never emitting null items.
+ self.values.extend(list.values())
+ };
+ Ok(())
+ })
+ }
+
+ /// Returns the number of distinct non-null values seen, as `Int64`.
+ fn evaluate(&self) -> Result<ScalarValue> {
+ Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+ }
+
+ /// Approximate heap plus inline size of this accumulator in bytes.
+ fn size(&self) -> usize {
+ // Capacity estimate: len scaled by 8/7 (a 7/8 max load factor) and
+ // rounded up to a power of two. NOTE(review): presumably mirrors the
+ // std/hashbrown growth policy; confirm if exact accounting matters.
+ let estimated_buckets =
(self.values.len().checked_mul(8).unwrap_or(usize::MAX)
+ / 7)
+ .next_power_of_two();
+
+ // Size of accumulator
+ // + size of entry * number of buckets
+ // + 1 byte for each bucket
+ // + fixed size of HashSet
+ std::mem::size_of_val(self)
+ + std::mem::size_of::<T::Native>() * estimated_buckets
+ + estimated_buckets
+ + std::mem::size_of_val(&self.values)
+ }
+}
+
+/// `COUNT(DISTINCT)` accumulator for floating-point arrow types. Float
+/// natives are not `Eq`/`Hash`, so each value is stored wrapped in `Hashable`.
+#[derive(Debug)]
+struct FloatDistinctCountAccumulator<T>
+where
+ T: ArrowPrimitiveType + Send,
+{
+ // Distinct non-null values seen so far, wrapped for hashing.
+ values: HashSet<Hashable<T::Native>, RandomState>,
+}
+
+impl<T> FloatDistinctCountAccumulator<T>
+where
+ T: ArrowPrimitiveType + Send,
+{
+ /// Creates an accumulator with an empty value set.
+ fn new() -> Self {
+ Self {
+ values: HashSet::default(),
+ }
+ }
+}
+
+impl<T> Accumulator for FloatDistinctCountAccumulator<T>
+where
+ T: ArrowPrimitiveType + Send + Debug,
+{
+ /// Emits the unwrapped distinct values as one `ScalarValue::List`, built by
+ /// wrapping a `PrimitiveArray<T>` with `array_into_list_array`.
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
+ self.values.iter().map(|v| v.0),
+ )) as ArrayRef;
+ let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
+ Ok(vec![ScalarValue::List(list)])
+ }
+
+ /// Inserts every non-null value of the first input column; nulls are
+ /// skipped. An empty `values` slice is a no-op.
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ if values.is_empty() {
+ return Ok(());
+ }
+
+ let arr = as_primitive_array::<T>(&values[0])?;
+ arr.iter().for_each(|value| {
+ if let Some(value) = value {
+ self.values.insert(Hashable(value));
+ }
+ });
+
+ Ok(())
+ }
+
+ /// Merges partial states. Each row of `states[0]` is a list of distinct
+ /// values in the shape produced by `state`, and null rows are skipped.
+ /// Panics if more than one state column is passed.
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ }
+ assert_eq!(
+ states.len(),
+ 1,
+ "count_distinct states must be single array"
+ );
+
+ let arr = as_list_array(&states[0])?;
+ arr.iter().try_for_each(|maybe_list| {
+ if let Some(list) = maybe_list {
+ let list = as_primitive_array::<T>(&list)?;
+ // Raw values buffer; any null bitmap is ignored (state never
+ // emits null items).
+ self.values
+ .extend(list.values().iter().map(|v| Hashable(*v)));
+ };
+ Ok(())
+ })
+ }
+
+ /// Returns the number of distinct non-null values seen, as `Int64`.
+ fn evaluate(&self) -> Result<ScalarValue> {
+ Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+ }
+
+ /// Approximate heap plus inline size of this accumulator in bytes.
+ fn size(&self) -> usize {
+ // Capacity estimate: len scaled by 8/7 and rounded up to a power of
+ // two. NOTE(review): presumably mirrors the std/hashbrown growth policy.
+ let estimated_buckets =
(self.values.len().checked_mul(8).unwrap_or(usize::MAX)
+ / 7)
+ .next_power_of_two();
+
+ // Size of accumulator
+ // + size of entry * number of buckets
+ // + 1 byte for each bucket
+ // + fixed size of HashSet
+ std::mem::size_of_val(self)
+ + std::mem::size_of::<T::Native>() * estimated_buckets
+ + estimated_buckets
+ + std::mem::size_of_val(&self.values)
+ }
+}
+
#[cfg(test)]
mod tests {
use crate::expressions::NoOp;
@@ -206,6 +452,8 @@ mod tests {
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
};
+ use arrow_array::Decimal256Array;
+ use arrow_buffer::i256;
use datafusion_common::cast::{as_boolean_array, as_list_array,
as_primitive_array};
use datafusion_common::internal_err;
use datafusion_common::DataFusionError;
@@ -367,6 +615,35 @@ mod tests {
}};
}
+ // Update-batch test for 256-bit integer arrays. The inputs are built with
+ // `i256::from` because `i256` has no literal syntax; `$PRIM_TYPE` is only
+ // the element type of the input vector. Two nulls plus repeats of 1, 2, 3
+ // must collapse to exactly three distinct values.
+ macro_rules! test_count_distinct_update_batch_bigint {
+ ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
+ let values: Vec<Option<$PRIM_TYPE>> = vec![
+ Some(i256::from(1)),
+ Some(i256::from(1)),
+ None,
+ Some(i256::from(3)),
+ Some(i256::from(2)),
+ None,
+ Some(i256::from(2)),
+ Some(i256::from(3)),
+ Some(i256::from(1)),
+ ];
+
+ let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];
+
+ let (states, result) = run_update_batch(&arrays)?;
+
+ // Hash-set iteration order is unspecified, so sort before comparing.
+ let mut state_vec = state_to_vec_primitive!(&states[0],
$DATA_TYPE);
+ state_vec.sort();
+
+ assert_eq!(states.len(), 1);
+ assert_eq!(state_vec, vec![i256::from(1), i256::from(2),
i256::from(3)]);
+ assert_eq!(result, ScalarValue::Int64(Some(3)));
+
+ Ok(())
+ }};
+ }
+
#[test]
fn count_distinct_update_batch_i8() -> Result<()> {
test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8)
@@ -417,6 +694,11 @@ mod tests {
test_count_distinct_update_batch_floating_point!(Float64Array,
Float64Type, f64)
}
+ // Decimal256 arrays hold `i256` native values, so this covers the 256-bit
+ // integer path of COUNT(DISTINCT).
+ #[test]
+ fn count_distinct_update_batch_i256() -> Result<()> {
+ test_count_distinct_update_batch_bigint!(Decimal256Array,
Decimal256Type, i256)
+ }
+
#[test]
fn count_distinct_update_batch_boolean() -> Result<()> {
let get_count = |data: BooleanArray| -> Result<(Vec<bool>, i64)> {
diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs
b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
index 0cf4a90ab8..6dbb392246 100644
--- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
@@ -25,11 +25,11 @@ use arrow::array::{Array, ArrayRef};
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType};
-use arrow_buffer::{ArrowNativeType, ToByteSlice};
+use arrow_buffer::ArrowNativeType;
use std::collections::HashSet;
use crate::aggregate::sum::downcast_sum;
-use crate::aggregate::utils::down_cast_any_ref;
+use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::type_coercion::aggregates::sum_return_type;
@@ -119,24 +119,6 @@ impl PartialEq<dyn Any> for DistinctSum {
}
}
-/// A wrapper around a type to provide hash for floats
-#[derive(Copy, Clone)]
-struct Hashable<T>(T);
-
-impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
- fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- self.0.to_byte_slice().hash(state)
- }
-}
-
-impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
- fn eq(&self, other: &Self) -> bool {
- self.0.is_eq(other.0)
- }
-}
-
-impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
-
struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
values: HashSet<Hashable<T::Native>, RandomState>,
data_type: DataType,
diff --git a/datafusion/physical-expr/src/aggregate/utils.rs
b/datafusion/physical-expr/src/aggregate/utils.rs
index 9777158da1..d73c46a0f6 100644
--- a/datafusion/physical-expr/src/aggregate/utils.rs
+++ b/datafusion/physical-expr/src/aggregate/utils.rs
@@ -28,7 +28,7 @@ use arrow_array::types::{
Decimal128Type, DecimalType, TimestampMicrosecondType,
TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType,
};
-use arrow_buffer::ArrowNativeType;
+use arrow_buffer::{ArrowNativeType, ToByteSlice};
use arrow_schema::{DataType, Field, SortOptions};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::Accumulator;
@@ -211,3 +211,21 @@ pub(crate) fn ordering_fields(
pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec<SortOptions>
{
ordering_req.iter().map(|item| item.options).collect()
}
+
+/// A wrapper that makes arrow native values usable as `HashSet`/`HashMap`
+/// keys. Floats in particular are not `Eq`/`Hash`.
+///
+/// Hashing uses the value's raw bytes (`ToByteSlice`), and equality uses
+/// `ArrowNativeTypeOp::is_eq`. These two must agree: equal values have to
+/// hash identically. NOTE(review): this presumably holds because `is_eq`
+/// compares floats by total order / bit pattern, which would make `0.0` and
+/// `-0.0` (and differently encoded NaNs) distinct keys. Confirm against the
+/// arrow-rs docs.
+#[derive(Copy, Clone, Debug)]
+pub(crate) struct Hashable<T>(pub T);
+
+impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.0.to_byte_slice().hash(state)
+ }
+}
+
+impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
+ fn eq(&self, other: &Self) -> bool {
+ self.0.is_eq(other.0)
+ }
+}
+
+impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}