This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 5991ae3dd7 Minor: remove duplication in Min/Max accumulator (#6960)
5991ae3dd7 is described below

commit 5991ae3dd70bb9eec1e8efa9f7c733b2c78859d1
Author: Andrew Lamb <[email protected]>
AuthorDate: Fri Jul 14 06:09:06 2023 -0400

    Minor: remove duplication in Min/Max accumulator (#6960)
---
 datafusion/physical-expr/src/aggregate/min_max.rs | 376 +++++-----------------
 1 file changed, 89 insertions(+), 287 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs
index ebf317e6d0..cc230c174b 100644
--- a/datafusion/physical-expr/src/aggregate/min_max.rs
+++ b/datafusion/physical-expr/src/aggregate/min_max.rs
@@ -21,6 +21,7 @@ use std::any::Any;
 use std::convert::TryFrom;
 use std::sync::Arc;
 
+use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
 use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr};
 use arrow::compute;
 use arrow::datatypes::{
@@ -39,16 +40,13 @@ use arrow::{
     },
     datatypes::Field,
 };
-use arrow_array::cast::AsArray;
 use arrow_array::types::{
     Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
     UInt16Type, UInt32Type, UInt64Type, UInt8Type,
 };
-use arrow_array::{ArrowNumericType, PrimitiveArray};
 use datafusion_common::ScalarValue;
 use datafusion_common::{downcast_value, DataFusionError, Result};
 use datafusion_expr::Accumulator;
-use log::debug;
 
 use crate::aggregate::row_accumulator::{
     is_row_accumulator_support_dtype, RowAccumulator,
@@ -59,9 +57,7 @@ use arrow::array::Array;
 use arrow::array::Decimal128Array;
 use datafusion_row::accessor::RowAccessor;
 
-use super::groups_accumulator::accumulate::NullState;
 use super::moving_min_max;
-use super::utils::adjust_output_array;
 
 // Min/max aggregation can take Dictionary encode input but always produces unpacked
 // (aka non Dictionary) output. We need to adjust the output data type to reflect this.
@@ -99,13 +95,46 @@ impl Max {
         }
     }
 }
+/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MAX`
+/// the specified [`ArrowPrimitiveType`].
+///
+/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
+macro_rules! instantiate_max_accumulator {
+    ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{
+        Ok(Box::new(
+            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(
+                &$SELF.data_type,
+                |cur, new| {
+                    if *cur < new {
+                        *cur = new
+                    }
+                },
+            )
+            // Initialize each accumulator to $NATIVE::MIN
+            .with_starting_value($NATIVE::MIN),
+        ))
+    }};
+}
 
-macro_rules! instantiate_min_max_accumulator {
-    ($SELF:expr, $NUMERICTYPE:ident, $MIN:expr) => {{
-        Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
-            $NUMERICTYPE,
-            $MIN,
-        >::new(&$SELF.data_type)))
+/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MIN`
+/// the specified [`ArrowPrimitiveType`].
+///
+///
+/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
+macro_rules! instantiate_min_accumulator {
+    ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{
+        Ok(Box::new(
+            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(
+                &$SELF.data_type,
+                |cur, new| {
+                    if *cur > new {
+                        *cur = new
+                    }
+                },
+            )
+            // Initialize each accumulator to $NATIVE::MAX
+            .with_starting_value($NATIVE::MAX),
+        ))
     }};
 }
 
@@ -184,56 +213,56 @@ impl AggregateExpr for Max {
         use TimeUnit::*;
 
         match self.data_type {
-            Int8 => instantiate_min_max_accumulator!(self, Int8Type, false),
-            Int16 => instantiate_min_max_accumulator!(self, Int16Type, false),
-            Int32 => instantiate_min_max_accumulator!(self, Int32Type, false),
-            Int64 => instantiate_min_max_accumulator!(self, Int64Type, false),
-            UInt8 => instantiate_min_max_accumulator!(self, UInt8Type, false),
-            UInt16 => instantiate_min_max_accumulator!(self, UInt16Type, false),
-            UInt32 => instantiate_min_max_accumulator!(self, UInt32Type, false),
-            UInt64 => instantiate_min_max_accumulator!(self, UInt64Type, false),
+            Int8 => instantiate_max_accumulator!(self, i8, Int8Type),
+            Int16 => instantiate_max_accumulator!(self, i16, Int16Type),
+            Int32 => instantiate_max_accumulator!(self, i32, Int32Type),
+            Int64 => instantiate_max_accumulator!(self, i64, Int64Type),
+            UInt8 => instantiate_max_accumulator!(self, u8, UInt8Type),
+            UInt16 => instantiate_max_accumulator!(self, u16, UInt16Type),
+            UInt32 => instantiate_max_accumulator!(self, u32, UInt32Type),
+            UInt64 => instantiate_max_accumulator!(self, u64, UInt64Type),
             Float32 => {
-                instantiate_min_max_accumulator!(self, Float32Type, false)
+                instantiate_max_accumulator!(self, f32, Float32Type)
             }
             Float64 => {
-                instantiate_min_max_accumulator!(self, Float64Type, false)
+                instantiate_max_accumulator!(self, f64, Float64Type)
             }
-            Date32 => instantiate_min_max_accumulator!(self, Date32Type, false),
-            Date64 => instantiate_min_max_accumulator!(self, Date64Type, false),
+            Date32 => instantiate_max_accumulator!(self, i32, Date32Type),
+            Date64 => instantiate_max_accumulator!(self, i64, Date64Type),
             Time32(Second) => {
-                instantiate_min_max_accumulator!(self, Time32SecondType, false)
+                instantiate_max_accumulator!(self, i32, Time32SecondType)
             }
             Time32(Millisecond) => {
-                instantiate_min_max_accumulator!(self, Time32MillisecondType, false)
+                instantiate_max_accumulator!(self, i32, Time32MillisecondType)
             }
             Time64(Microsecond) => {
-                instantiate_min_max_accumulator!(self, Time64MicrosecondType, false)
+                instantiate_max_accumulator!(self, i64, Time64MicrosecondType)
             }
             Time64(Nanosecond) => {
-                instantiate_min_max_accumulator!(self, Time64NanosecondType, false)
+                instantiate_max_accumulator!(self, i64, Time64NanosecondType)
             }
             Timestamp(Second, _) => {
-                instantiate_min_max_accumulator!(self, TimestampSecondType, false)
+                instantiate_max_accumulator!(self, i64, TimestampSecondType)
             }
             Timestamp(Millisecond, _) => {
-                instantiate_min_max_accumulator!(self, TimestampMillisecondType, false)
+                instantiate_max_accumulator!(self, i64, TimestampMillisecondType)
             }
             Timestamp(Microsecond, _) => {
-                instantiate_min_max_accumulator!(self, TimestampMicrosecondType, false)
+                instantiate_max_accumulator!(self, i64, TimestampMicrosecondType)
             }
             Timestamp(Nanosecond, _) => {
-                instantiate_min_max_accumulator!(self, TimestampNanosecondType, false)
+                instantiate_max_accumulator!(self, i64, TimestampNanosecondType)
+            }
+            Decimal128(_, _) => {
+                instantiate_max_accumulator!(self, i128, Decimal128Type)
             }
 
             // It would be nice to have a fast implementation for Strings as well
             // https://github.com/apache/arrow-datafusion/issues/6906
-            Decimal128(_, _) => Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
-                Decimal128Type,
-                false,
-            >::new(&self.data_type))),
+
             // This is only reached if groups_accumulator_supported is out of sync
             _ => Err(DataFusionError::Internal(format!(
-                "MinMaxGroupsPrimitiveAccumulator not supported for max({})",
+                "GroupsAccumulator not supported for max({})",
                 self.data_type
             ))),
         }
@@ -965,53 +994,52 @@ impl AggregateExpr for Min {
         use DataType::*;
         use TimeUnit::*;
         match self.data_type {
-            Int8 => instantiate_min_max_accumulator!(self, Int8Type, true),
-            Int16 => instantiate_min_max_accumulator!(self, Int16Type, true),
-            Int32 => instantiate_min_max_accumulator!(self, Int32Type, true),
-            Int64 => instantiate_min_max_accumulator!(self, Int64Type, true),
-            UInt8 => instantiate_min_max_accumulator!(self, UInt8Type, true),
-            UInt16 => instantiate_min_max_accumulator!(self, UInt16Type, true),
-            UInt32 => instantiate_min_max_accumulator!(self, UInt32Type, true),
-            UInt64 => instantiate_min_max_accumulator!(self, UInt64Type, true),
+            Int8 => instantiate_min_accumulator!(self, i8, Int8Type),
+            Int16 => instantiate_min_accumulator!(self, i16, Int16Type),
+            Int32 => instantiate_min_accumulator!(self, i32, Int32Type),
+            Int64 => instantiate_min_accumulator!(self, i64, Int64Type),
+            UInt8 => instantiate_min_accumulator!(self, u8, UInt8Type),
+            UInt16 => instantiate_min_accumulator!(self, u16, UInt16Type),
+            UInt32 => instantiate_min_accumulator!(self, u32, UInt32Type),
+            UInt64 => instantiate_min_accumulator!(self, u64, UInt64Type),
             Float32 => {
-                instantiate_min_max_accumulator!(self, Float32Type, true)
+                instantiate_min_accumulator!(self, f32, Float32Type)
             }
             Float64 => {
-                instantiate_min_max_accumulator!(self, Float64Type, true)
+                instantiate_min_accumulator!(self, f64, Float64Type)
             }
-            Date32 => instantiate_min_max_accumulator!(self, Date32Type, true),
-            Date64 => instantiate_min_max_accumulator!(self, Date64Type, true),
+            Date32 => instantiate_min_accumulator!(self, i32, Date32Type),
+            Date64 => instantiate_min_accumulator!(self, i64, Date64Type),
             Time32(Second) => {
-                instantiate_min_max_accumulator!(self, Time32SecondType, true)
+                instantiate_min_accumulator!(self, i32, Time32SecondType)
             }
             Time32(Millisecond) => {
-                instantiate_min_max_accumulator!(self, Time32MillisecondType, true)
+                instantiate_min_accumulator!(self, i32, Time32MillisecondType)
             }
             Time64(Microsecond) => {
-                instantiate_min_max_accumulator!(self, Time64MicrosecondType, true)
+                instantiate_min_accumulator!(self, i64, Time64MicrosecondType)
             }
             Time64(Nanosecond) => {
-                instantiate_min_max_accumulator!(self, Time64NanosecondType, true)
+                instantiate_min_accumulator!(self, i64, Time64NanosecondType)
             }
             Timestamp(Second, _) => {
-                instantiate_min_max_accumulator!(self, TimestampSecondType, true)
+                instantiate_min_accumulator!(self, i64, TimestampSecondType)
             }
             Timestamp(Millisecond, _) => {
-                instantiate_min_max_accumulator!(self, TimestampMillisecondType, true)
+                instantiate_min_accumulator!(self, i64, TimestampMillisecondType)
             }
             Timestamp(Microsecond, _) => {
-                instantiate_min_max_accumulator!(self, TimestampMicrosecondType, true)
+                instantiate_min_accumulator!(self, i64, TimestampMicrosecondType)
             }
             Timestamp(Nanosecond, _) => {
-                instantiate_min_max_accumulator!(self, TimestampNanosecondType, true)
+                instantiate_min_accumulator!(self, i64, TimestampNanosecondType)
+            }
+            Decimal128(_, _) => {
+                instantiate_min_accumulator!(self, i128, Decimal128Type)
             }
-            Decimal128(_, _) => Ok(Box::new(MinMaxGroupsPrimitiveAccumulator::<
-                Decimal128Type,
-                true,
-            >::new(&self.data_type))),
             // This is only reached if groups_accumulator_supported is out of sync
             _ => Err(DataFusionError::Internal(format!(
-                "MinMaxGroupsPrimitiveAccumulator not supported for min({})",
+                "GroupsAccumulator not supported for min({})",
                 self.data_type
             ))),
         }
@@ -1204,232 +1232,6 @@ impl RowAccumulator for MinRowAccumulator {
     }
 }
 
-trait MinMax {
-    fn min() -> Self;
-    fn max() -> Self;
-}
-
-impl MinMax for u8 {
-    fn min() -> Self {
-        u8::MIN
-    }
-    fn max() -> Self {
-        u8::MAX
-    }
-}
-impl MinMax for i8 {
-    fn min() -> Self {
-        i8::MIN
-    }
-    fn max() -> Self {
-        i8::MAX
-    }
-}
-impl MinMax for u16 {
-    fn min() -> Self {
-        u16::MIN
-    }
-    fn max() -> Self {
-        u16::MAX
-    }
-}
-impl MinMax for i16 {
-    fn min() -> Self {
-        i16::MIN
-    }
-    fn max() -> Self {
-        i16::MAX
-    }
-}
-impl MinMax for u32 {
-    fn min() -> Self {
-        u32::MIN
-    }
-    fn max() -> Self {
-        u32::MAX
-    }
-}
-impl MinMax for i32 {
-    fn min() -> Self {
-        i32::MIN
-    }
-    fn max() -> Self {
-        i32::MAX
-    }
-}
-impl MinMax for i64 {
-    fn min() -> Self {
-        i64::MIN
-    }
-    fn max() -> Self {
-        i64::MAX
-    }
-}
-impl MinMax for u64 {
-    fn min() -> Self {
-        u64::MIN
-    }
-    fn max() -> Self {
-        u64::MAX
-    }
-}
-impl MinMax for f32 {
-    fn min() -> Self {
-        f32::MIN
-    }
-    fn max() -> Self {
-        f32::MAX
-    }
-}
-impl MinMax for f64 {
-    fn min() -> Self {
-        f64::MIN
-    }
-    fn max() -> Self {
-        f64::MAX
-    }
-}
-impl MinMax for i128 {
-    fn min() -> Self {
-        i128::MIN
-    }
-    fn max() -> Self {
-        i128::MAX
-    }
-}
-
-/// An accumulator to compute the min or max of a [`PrimitiveArray<T>`].
-///
-/// Stores values as native/primitive type
-///
-/// Note this doesn't use [`PrimitiveGroupsAccumulator`] because it
-/// needs to control the default accumulator value (which is not
-/// `default::Default()`)
-///
-/// [`PrimitiveGroupsAccumulator`]: crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator
-#[derive(Debug)]
-struct MinMaxGroupsPrimitiveAccumulator<T, const MIN: bool>
-where
-    T: ArrowNumericType + Send,
-    T::Native: MinMax,
-{
-    /// Min/max per group, stored as the native type
-    min_max: Vec<T::Native>,
-
-    /// Track nulls in the input / filters
-    null_state: NullState,
-
-    /// The output datatype (needed for decimal precision/scale)
-    data_type: DataType,
-}
-
-impl<T, const MIN: bool> MinMaxGroupsPrimitiveAccumulator<T, MIN>
-where
-    T: ArrowNumericType + Send,
-    T::Native: MinMax,
-{
-    pub fn new(data_type: &DataType) -> Self {
-        debug!(
-            "MinMaxGroupsPrimitiveAccumulator ({}, {})",
-            std::any::type_name::<T>(),
-            MIN,
-        );
-
-        Self {
-            min_max: vec![],
-            null_state: NullState::new(),
-            data_type: data_type.clone(),
-        }
-    }
-}
-
-impl<T, const MIN: bool> GroupsAccumulator for MinMaxGroupsPrimitiveAccumulator<T, MIN>
-where
-    T: ArrowNumericType + Send,
-    T::Native: MinMax,
-{
-    fn update_batch(
-        &mut self,
-        values: &[ArrayRef],
-        group_indices: &[usize],
-        opt_filter: Option<&arrow_array::BooleanArray>,
-        total_num_groups: usize,
-    ) -> Result<()> {
-        assert_eq!(values.len(), 1, "single argument to update_batch");
-        let values = values[0].as_primitive::<T>();
-
-        self.min_max.resize(
-            total_num_groups,
-            if MIN {
-                T::Native::max()
-            } else {
-                T::Native::min()
-            },
-        );
-
-        // NullState dispatches / handles tracking nulls and groups that saw no values
-        self.null_state.accumulate(
-            group_indices,
-            values,
-            opt_filter,
-            total_num_groups,
-            |group_index, new_value| {
-                let val = &mut self.min_max[group_index];
-                match MIN {
-                    true => {
-                        if new_value < *val {
-                            *val = new_value;
-                        }
-                    }
-                    false => {
-                        if new_value > *val {
-                            *val = new_value;
-                        }
-                    }
-                }
-            },
-        );
-
-        Ok(())
-    }
-
-    fn merge_batch(
-        &mut self,
-        values: &[ArrayRef],
-        group_indices: &[usize],
-        opt_filter: Option<&arrow_array::BooleanArray>,
-        total_num_groups: usize,
-    ) -> Result<()> {
-        Self::update_batch(self, values, group_indices, opt_filter, total_num_groups)
-    }
-
-    fn evaluate(&mut self) -> Result<ArrayRef> {
-        let min_max = std::mem::take(&mut self.min_max);
-        let nulls = self.null_state.build();
-
-        let min_max = PrimitiveArray::<T>::new(min_max.into(), Some(nulls)); // no copy
-        let min_max = adjust_output_array(&self.data_type, Arc::new(min_max))?;
-
-        Ok(Arc::new(min_max))
-    }
-
-    // return arrays for min/max values
-    fn state(&mut self) -> Result<Vec<ArrayRef>> {
-        let nulls = self.null_state.build();
-
-        let min_max = std::mem::take(&mut self.min_max);
-        let min_max = PrimitiveArray::<T>::new(min_max.into(), Some(nulls)); // zero copy
-
-        let min_max = adjust_output_array(&self.data_type, Arc::new(min_max))?;
-
-        Ok(vec![min_max])
-    }
-
-    fn size(&self) -> usize {
-        self.min_max.capacity() * std::mem::size_of::<T>() + self.null_state.size()
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;

Reply via email to