This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 5991ae3dd7 Minor: remove duplication in Min/Max accumulator (#6960)
5991ae3dd7 is described below
commit 5991ae3dd70bb9eec1e8efa9f7c733b2c78859d1
Author: Andrew Lamb <[email protected]>
AuthorDate: Fri Jul 14 06:09:06 2023 -0400
Minor: remove duplication in Min/Max accumulator (#6960)
---
datafusion/physical-expr/src/aggregate/min_max.rs | 376 +++++-----------------
1 file changed, 89 insertions(+), 287 deletions(-)
diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs
index ebf317e6d0..cc230c174b 100644
--- a/datafusion/physical-expr/src/aggregate/min_max.rs
+++ b/datafusion/physical-expr/src/aggregate/min_max.rs
@@ -21,6 +21,7 @@ use std::any::Any;
use std::convert::TryFrom;
use std::sync::Arc;
+use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr};
use arrow::compute;
use arrow::datatypes::{
@@ -39,16 +40,13 @@ use arrow::{
},
datatypes::Field,
};
-use arrow_array::cast::AsArray;
use arrow_array::types::{
Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
Int8Type,
UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
-use arrow_array::{ArrowNumericType, PrimitiveArray};
use datafusion_common::ScalarValue;
use datafusion_common::{downcast_value, DataFusionError, Result};
use datafusion_expr::Accumulator;
-use log::debug;
use crate::aggregate::row_accumulator::{
is_row_accumulator_support_dtype, RowAccumulator,
@@ -59,9 +57,7 @@ use arrow::array::Array;
use arrow::array::Decimal128Array;
use datafusion_row::accessor::RowAccessor;
-use super::groups_accumulator::accumulate::NullState;
use super::moving_min_max;
-use super::utils::adjust_output_array;
// Min/max aggregation can take Dictionary encode input but always produces unpacked
// (aka non Dictionary) output. We need to adjust the output data type to reflect this.
@@ -99,13 +95,46 @@ impl Max {
}
}
}
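+/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MAX` for
+/// the specified [`ArrowPrimitiveType`].
+///
+/// NOTE(review): each group starts at `$NATIVE::MIN` and is only replaced when
+/// `*cur < new`. For `f32`/`f64`, `MIN` is the most negative *finite* value, so a
+/// group whose only inputs are `-inf` reports `$NATIVE::MIN` rather than `-inf`,
+/// and `NaN` inputs never replace the current value.
+///
+/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType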
+macro_rules! instantiate_max_accumulator {
+ ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{
+ Ok(Box::new(
+ PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(
+ &$SELF.data_type,
+ |cur, new| {
+ if *cur < new {
+ *cur = new
+ }
+ },
+ )
+ // Initialize each accumulator to $NATIVE::MIN
+ .with_starting_value($NATIVE::MIN),
+ ))
+ }};
+}
-macro_rules! instantiate_min_max_accumulator {
- ($SELF:expr, $NUMERICTYPE:ident, $MIN:expr) => {{
- Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
- $NUMERICTYPE,
- $MIN,
- >::new(&$SELF.data_type)))
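+/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MIN` for
+/// the specified [`ArrowPrimitiveType`].
+///
+/// NOTE(review): each group starts at `$NATIVE::MAX` and is only replaced when
+/// `*cur > new`. For `f32`/`f64`, `MAX` is the largest *finite* value, so a
+/// group whose only inputs are `+inf` reports `$NATIVE::MAX` rather than `+inf`,
+/// and `NaN` inputs never replace the current value.
+///
+/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType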
+macro_rules! instantiate_min_accumulator {
+ ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{
+ Ok(Box::new(
+ PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(
+ &$SELF.data_type,
+ |cur, new| {
+ if *cur > new {
+ *cur = new
+ }
+ },
+ )
+ // Initialize each accumulator to $NATIVE::MAX
+ .with_starting_value($NATIVE::MAX),
+ ))
}};
}
@@ -184,56 +213,56 @@ impl AggregateExpr for Max {
use TimeUnit::*;
match self.data_type {
- Int8 => instantiate_min_max_accumulator!(self, Int8Type, false),
- Int16 => instantiate_min_max_accumulator!(self, Int16Type, false),
- Int32 => instantiate_min_max_accumulator!(self, Int32Type, false),
- Int64 => instantiate_min_max_accumulator!(self, Int64Type, false),
- UInt8 => instantiate_min_max_accumulator!(self, UInt8Type, false),
- UInt16 => instantiate_min_max_accumulator!(self, UInt16Type,
false),
- UInt32 => instantiate_min_max_accumulator!(self, UInt32Type,
false),
- UInt64 => instantiate_min_max_accumulator!(self, UInt64Type,
false),
+ Int8 => instantiate_max_accumulator!(self, i8, Int8Type),
+ Int16 => instantiate_max_accumulator!(self, i16, Int16Type),
+ Int32 => instantiate_max_accumulator!(self, i32, Int32Type),
+ Int64 => instantiate_max_accumulator!(self, i64, Int64Type),
+ UInt8 => instantiate_max_accumulator!(self, u8, UInt8Type),
+ UInt16 => instantiate_max_accumulator!(self, u16, UInt16Type),
+ UInt32 => instantiate_max_accumulator!(self, u32, UInt32Type),
+ UInt64 => instantiate_max_accumulator!(self, u64, UInt64Type),
Float32 => {
- instantiate_min_max_accumulator!(self, Float32Type, false)
+ instantiate_max_accumulator!(self, f32, Float32Type)
}
Float64 => {
- instantiate_min_max_accumulator!(self, Float64Type, false)
+ instantiate_max_accumulator!(self, f64, Float64Type)
}
- Date32 => instantiate_min_max_accumulator!(self, Date32Type,
false),
- Date64 => instantiate_min_max_accumulator!(self, Date64Type,
false),
+ Date32 => instantiate_max_accumulator!(self, i32, Date32Type),
+ Date64 => instantiate_max_accumulator!(self, i64, Date64Type),
Time32(Second) => {
- instantiate_min_max_accumulator!(self, Time32SecondType, false)
+ instantiate_max_accumulator!(self, i32, Time32SecondType)
}
Time32(Millisecond) => {
- instantiate_min_max_accumulator!(self, Time32MillisecondType,
false)
+ instantiate_max_accumulator!(self, i32, Time32MillisecondType)
}
Time64(Microsecond) => {
- instantiate_min_max_accumulator!(self, Time64MicrosecondType,
false)
+ instantiate_max_accumulator!(self, i64, Time64MicrosecondType)
}
Time64(Nanosecond) => {
- instantiate_min_max_accumulator!(self, Time64NanosecondType,
false)
+ instantiate_max_accumulator!(self, i64, Time64NanosecondType)
}
Timestamp(Second, _) => {
- instantiate_min_max_accumulator!(self, TimestampSecondType,
false)
+ instantiate_max_accumulator!(self, i64, TimestampSecondType)
}
Timestamp(Millisecond, _) => {
- instantiate_min_max_accumulator!(self,
TimestampMillisecondType, false)
+ instantiate_max_accumulator!(self, i64,
TimestampMillisecondType)
}
Timestamp(Microsecond, _) => {
- instantiate_min_max_accumulator!(self,
TimestampMicrosecondType, false)
+ instantiate_max_accumulator!(self, i64,
TimestampMicrosecondType)
}
Timestamp(Nanosecond, _) => {
- instantiate_min_max_accumulator!(self,
TimestampNanosecondType, false)
+ instantiate_max_accumulator!(self, i64,
TimestampNanosecondType)
+ }
+ Decimal128(_, _) => {
+ instantiate_max_accumulator!(self, i128, Decimal128Type)
}
// It would be nice to have a fast implementation for Strings as well
// https://github.com/apache/arrow-datafusion/issues/6906
- Decimal128(_, _) => Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
- Decimal128Type,
- false,
- >::new(&self.data_type))),
+
// This is only reached if groups_accumulator_supported is out of sync
_ => Err(DataFusionError::Internal(format!(
- "MinMaxGroupsPrimitiveAccumulator not supported for max({})",
+ "GroupsAccumulator not supported for max({})",
self.data_type
))),
}
@@ -965,53 +994,52 @@ impl AggregateExpr for Min {
use DataType::*;
use TimeUnit::*;
match self.data_type {
- Int8 => instantiate_min_max_accumulator!(self, Int8Type, true),
- Int16 => instantiate_min_max_accumulator!(self, Int16Type, true),
- Int32 => instantiate_min_max_accumulator!(self, Int32Type, true),
- Int64 => instantiate_min_max_accumulator!(self, Int64Type, true),
- UInt8 => instantiate_min_max_accumulator!(self, UInt8Type, true),
- UInt16 => instantiate_min_max_accumulator!(self, UInt16Type, true),
- UInt32 => instantiate_min_max_accumulator!(self, UInt32Type, true),
- UInt64 => instantiate_min_max_accumulator!(self, UInt64Type, true),
+ Int8 => instantiate_min_accumulator!(self, i8, Int8Type),
+ Int16 => instantiate_min_accumulator!(self, i16, Int16Type),
+ Int32 => instantiate_min_accumulator!(self, i32, Int32Type),
+ Int64 => instantiate_min_accumulator!(self, i64, Int64Type),
+ UInt8 => instantiate_min_accumulator!(self, u8, UInt8Type),
+ UInt16 => instantiate_min_accumulator!(self, u16, UInt16Type),
+ UInt32 => instantiate_min_accumulator!(self, u32, UInt32Type),
+ UInt64 => instantiate_min_accumulator!(self, u64, UInt64Type),
Float32 => {
- instantiate_min_max_accumulator!(self, Float32Type, true)
+ instantiate_min_accumulator!(self, f32, Float32Type)
}
Float64 => {
- instantiate_min_max_accumulator!(self, Float64Type, true)
+ instantiate_min_accumulator!(self, f64, Float64Type)
}
- Date32 => instantiate_min_max_accumulator!(self, Date32Type, true),
- Date64 => instantiate_min_max_accumulator!(self, Date64Type, true),
+ Date32 => instantiate_min_accumulator!(self, i32, Date32Type),
+ Date64 => instantiate_min_accumulator!(self, i64, Date64Type),
Time32(Second) => {
- instantiate_min_max_accumulator!(self, Time32SecondType, true)
+ instantiate_min_accumulator!(self, i32, Time32SecondType)
}
Time32(Millisecond) => {
- instantiate_min_max_accumulator!(self, Time32MillisecondType,
true)
+ instantiate_min_accumulator!(self, i32, Time32MillisecondType)
}
Time64(Microsecond) => {
- instantiate_min_max_accumulator!(self, Time64MicrosecondType,
true)
+ instantiate_min_accumulator!(self, i64, Time64MicrosecondType)
}
Time64(Nanosecond) => {
- instantiate_min_max_accumulator!(self, Time64NanosecondType,
true)
+ instantiate_min_accumulator!(self, i64, Time64NanosecondType)
}
Timestamp(Second, _) => {
- instantiate_min_max_accumulator!(self, TimestampSecondType,
true)
+ instantiate_min_accumulator!(self, i64, TimestampSecondType)
}
Timestamp(Millisecond, _) => {
- instantiate_min_max_accumulator!(self,
TimestampMillisecondType, true)
+ instantiate_min_accumulator!(self, i64,
TimestampMillisecondType)
}
Timestamp(Microsecond, _) => {
- instantiate_min_max_accumulator!(self,
TimestampMicrosecondType, true)
+ instantiate_min_accumulator!(self, i64,
TimestampMicrosecondType)
}
Timestamp(Nanosecond, _) => {
- instantiate_min_max_accumulator!(self,
TimestampNanosecondType, true)
+ instantiate_min_accumulator!(self, i64,
TimestampNanosecondType)
+ }
+ Decimal128(_, _) => {
+ instantiate_min_accumulator!(self, i128, Decimal128Type)
}
- Decimal128(_, _) => Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
- Decimal128Type,
- true,
- >::new(&self.data_type))),
// This is only reached if groups_accumulator_supported is out of sync
_ => Err(DataFusionError::Internal(format!(
- "MinMaxGroupsPrimitiveAccumulator not supported for min({})",
+ "GroupsAccumulator not supported for min({})",
self.data_type
))),
}
@@ -1204,232 +1232,6 @@ impl RowAccumulator for MinRowAccumulator {
}
}
-trait MinMax {
- fn min() -> Self;
- fn max() -> Self;
-}
-
-impl MinMax for u8 {
- fn min() -> Self {
- u8::MIN
- }
- fn max() -> Self {
- u8::MAX
- }
-}
-impl MinMax for i8 {
- fn min() -> Self {
- i8::MIN
- }
- fn max() -> Self {
- i8::MAX
- }
-}
-impl MinMax for u16 {
- fn min() -> Self {
- u16::MIN
- }
- fn max() -> Self {
- u16::MAX
- }
-}
-impl MinMax for i16 {
- fn min() -> Self {
- i16::MIN
- }
- fn max() -> Self {
- i16::MAX
- }
-}
-impl MinMax for u32 {
- fn min() -> Self {
- u32::MIN
- }
- fn max() -> Self {
- u32::MAX
- }
-}
-impl MinMax for i32 {
- fn min() -> Self {
- i32::MIN
- }
- fn max() -> Self {
- i32::MAX
- }
-}
-impl MinMax for i64 {
- fn min() -> Self {
- i64::MIN
- }
- fn max() -> Self {
- i64::MAX
- }
-}
-impl MinMax for u64 {
- fn min() -> Self {
- u64::MIN
- }
- fn max() -> Self {
- u64::MAX
- }
-}
-impl MinMax for f32 {
- fn min() -> Self {
- f32::MIN
- }
- fn max() -> Self {
- f32::MAX
- }
-}
-impl MinMax for f64 {
- fn min() -> Self {
- f64::MIN
- }
- fn max() -> Self {
- f64::MAX
- }
-}
-impl MinMax for i128 {
- fn min() -> Self {
- i128::MIN
- }
- fn max() -> Self {
- i128::MAX
- }
-}
-
-/// An accumulator to compute the min or max of a [`PrimitiveArray<T>`].
-///
-/// Stores values as native/primitive type
-///
-/// Note this doesn't use [`PrimitiveGroupsAccumulator`] because it
-/// needs to control the default accumulator value (which is not
-/// `default::Default()`)
-///
-/// [`PrimitiveGroupsAccumulator`]: crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator
-#[derive(Debug)]
-struct MinMaxGroupsPrimitiveAccumulator<T, const MIN: bool>
-where
- T: ArrowNumericType + Send,
- T::Native: MinMax,
-{
- /// Min/max per group, stored as the native type
- min_max: Vec<T::Native>,
-
- /// Track nulls in the input / filters
- null_state: NullState,
-
- /// The output datatype (needed for decimal precision/scale)
- data_type: DataType,
-}
-
-impl<T, const MIN: bool> MinMaxGroupsPrimitiveAccumulator<T, MIN>
-where
- T: ArrowNumericType + Send,
- T::Native: MinMax,
-{
- pub fn new(data_type: &DataType) -> Self {
- debug!(
- "MinMaxGroupsPrimitiveAccumulator ({}, {})",
- std::any::type_name::<T>(),
- MIN,
- );
-
- Self {
- min_max: vec![],
- null_state: NullState::new(),
- data_type: data_type.clone(),
- }
- }
-}
-
-impl<T, const MIN: bool> GroupsAccumulator for
MinMaxGroupsPrimitiveAccumulator<T, MIN>
-where
- T: ArrowNumericType + Send,
- T::Native: MinMax,
-{
- fn update_batch(
- &mut self,
- values: &[ArrayRef],
- group_indices: &[usize],
- opt_filter: Option<&arrow_array::BooleanArray>,
- total_num_groups: usize,
- ) -> Result<()> {
- assert_eq!(values.len(), 1, "single argument to update_batch");
- let values = values[0].as_primitive::<T>();
-
- self.min_max.resize(
- total_num_groups,
- if MIN {
- T::Native::max()
- } else {
- T::Native::min()
- },
- );
-
- // NullState dispatches / handles tracking nulls and groups that saw no values
- self.null_state.accumulate(
- group_indices,
- values,
- opt_filter,
- total_num_groups,
- |group_index, new_value| {
- let val = &mut self.min_max[group_index];
- match MIN {
- true => {
- if new_value < *val {
- *val = new_value;
- }
- }
- false => {
- if new_value > *val {
- *val = new_value;
- }
- }
- }
- },
- );
-
- Ok(())
- }
-
- fn merge_batch(
- &mut self,
- values: &[ArrayRef],
- group_indices: &[usize],
- opt_filter: Option<&arrow_array::BooleanArray>,
- total_num_groups: usize,
- ) -> Result<()> {
- Self::update_batch(self, values, group_indices, opt_filter,
total_num_groups)
- }
-
- fn evaluate(&mut self) -> Result<ArrayRef> {
- let min_max = std::mem::take(&mut self.min_max);
- let nulls = self.null_state.build();
-
- let min_max = PrimitiveArray::<T>::new(min_max.into(), Some(nulls));
// no copy
- let min_max = adjust_output_array(&self.data_type, Arc::new(min_max))?;
-
- Ok(Arc::new(min_max))
- }
-
- // return arrays for min/max values
- fn state(&mut self) -> Result<Vec<ArrayRef>> {
- let nulls = self.null_state.build();
-
- let min_max = std::mem::take(&mut self.min_max);
- let min_max = PrimitiveArray::<T>::new(min_max.into(), Some(nulls));
// zero copy
-
- let min_max = adjust_output_array(&self.data_type, Arc::new(min_max))?;
-
- Ok(vec![min_max])
- }
-
- fn size(&self) -> usize {
- self.min_max.capacity() * std::mem::size_of::<T>() +
self.null_state.size()
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;