jayzhan211 commented on code in PR #11013:
URL: https://github.com/apache/datafusion/pull/11013#discussion_r1694066849


##########
datafusion/functions-aggregate/src/min_max.rs:
##########
@@ -123,170 +201,163 @@ macro_rules! instantiate_max_accumulator {
 ///
 /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
 macro_rules! instantiate_min_accumulator {
-    ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{
+    ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{
         Ok(Box::new(
-            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(
-                &$SELF.data_type,
-                |cur, new| {
-                    if *cur > new {
-                        *cur = new
-                    }
-                },
-            )
+            PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, 
new| {
+                if *cur > new {
+                    *cur = new
+                }
+            })
             // Initialize each accumulator to $NATIVE::MAX
             .with_starting_value($NATIVE::MAX),
         ))
     }};
 }
 
-impl AggregateExpr for Max {
-    /// Return a reference to Any that can be used for downcasting
-    fn as_any(&self) -> &dyn Any {
+impl AggregateUDFImpl for Max {
+    fn as_any(&self) -> &dyn std::any::Any {
         self
     }
 
-    fn field(&self) -> Result<Field> {
-        Ok(Field::new(
-            &self.name,
-            self.data_type.clone(),
-            self.nullable,
-        ))
+    fn name(&self) -> &str {
+        "MAX"
     }
 
-    fn state_fields(&self) -> Result<Vec<Field>> {
-        Ok(vec![Field::new(
-            format_state_name(&self.name, "max"),
-            self.data_type.clone(),
-            true,
-        )])
+    fn signature(&self) -> &Signature {
+        &self.signature
     }
 
-    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        vec![Arc::clone(&self.expr)]
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        type_coercion::aggregates::get_min_max_result_type(arg_types)?
+            .into_iter()
+            .next()
+            .ok_or_else(|| {
+                DataFusionError::Internal(format!(
+                    "Expected at one input type for MAX aggregate function"
+                ))
+            })
     }
 
-    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(MaxAccumulator::try_new(&self.data_type)?))
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
+        // let data_type = &min_max_aggregate_data_type(acc_args.data_type);
+        let data_type = acc_args.input_type;
+        Ok(Box::new(MaxAccumulator::try_new(data_type)?))
     }
 
-    fn name(&self) -> &str {
-        &self.name
+    fn aliases(&self) -> &[String] {
+        &self.aliases
     }
 
-    fn groups_accumulator_supported(&self) -> bool {
-        use DataType::*;
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        // let data_type = min_max_aggregate_data_type(args.data_type);
+        let data_type = args.input_type;
         matches!(
-            self.data_type,
-            Int8 | Int16
-                | Int32
-                | Int64
-                | UInt8
-                | UInt16
-                | UInt32
-                | UInt64
-                | Float32
-                | Float64
-                | Decimal128(_, _)
-                | Decimal256(_, _)
-                | Date32
-                | Date64
-                | Time32(_)
-                | Time64(_)
-                | Timestamp(_, _)
+            data_type,
+            DataType::Int8
+                | DataType::Int16
+                | DataType::Int32
+                | DataType::Int64
+                | DataType::UInt8
+                | DataType::UInt16
+                | DataType::UInt32
+                | DataType::UInt64
+                | DataType::Float32
+                | DataType::Float64
+                | DataType::Decimal128(_, _)
+                | DataType::Decimal256(_, _)
+                | DataType::Date32
+                | DataType::Date64
+                | DataType::Time32(_)
+                | DataType::Time64(_)
+                | DataType::Timestamp(_, _)
         )
     }
 
-    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+    fn create_groups_accumulator(
+        &self,
+        args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
         use DataType::*;
         use TimeUnit::*;
-
-        match self.data_type {
-            Int8 => instantiate_max_accumulator!(self, i8, Int8Type),
-            Int16 => instantiate_max_accumulator!(self, i16, Int16Type),
-            Int32 => instantiate_max_accumulator!(self, i32, Int32Type),
-            Int64 => instantiate_max_accumulator!(self, i64, Int64Type),
-            UInt8 => instantiate_max_accumulator!(self, u8, UInt8Type),
-            UInt16 => instantiate_max_accumulator!(self, u16, UInt16Type),
-            UInt32 => instantiate_max_accumulator!(self, u32, UInt32Type),
-            UInt64 => instantiate_max_accumulator!(self, u64, UInt64Type),
+//        let data_type = min_max_aggregate_data_type(args.data_type);
+        let data_type = args.input_type;
+        match data_type {
+            Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type),
+            Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type),
+            Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type),
+            Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type),
+            UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type),
+            UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type),
+            UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type),
+            UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type),
             Float32 => {
-                instantiate_max_accumulator!(self, f32, Float32Type)
+                instantiate_max_accumulator!(data_type, f32, Float32Type)
             }
             Float64 => {
-                instantiate_max_accumulator!(self, f64, Float64Type)
+                instantiate_max_accumulator!(data_type, f64, Float64Type)
             }
-            Date32 => instantiate_max_accumulator!(self, i32, Date32Type),
-            Date64 => instantiate_max_accumulator!(self, i64, Date64Type),
+            Date32 => instantiate_max_accumulator!(data_type, i32, Date32Type),
+            Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type),
             Time32(Second) => {
-                instantiate_max_accumulator!(self, i32, Time32SecondType)
+                instantiate_max_accumulator!(data_type, i32, Time32SecondType)
             }
             Time32(Millisecond) => {
-                instantiate_max_accumulator!(self, i32, Time32MillisecondType)
+                instantiate_max_accumulator!(data_type, i32, 
Time32MillisecondType)
             }
             Time64(Microsecond) => {
-                instantiate_max_accumulator!(self, i64, Time64MicrosecondType)
+                instantiate_max_accumulator!(data_type, i64, 
Time64MicrosecondType)
             }
             Time64(Nanosecond) => {
-                instantiate_max_accumulator!(self, i64, Time64NanosecondType)
+                instantiate_max_accumulator!(data_type, i64, 
Time64NanosecondType)
             }
             Timestamp(Second, _) => {
-                instantiate_max_accumulator!(self, i64, TimestampSecondType)
+                instantiate_max_accumulator!(data_type, i64, 
TimestampSecondType)
             }
             Timestamp(Millisecond, _) => {
-                instantiate_max_accumulator!(self, i64, 
TimestampMillisecondType)
+                instantiate_max_accumulator!(data_type, i64, 
TimestampMillisecondType)
             }
             Timestamp(Microsecond, _) => {
-                instantiate_max_accumulator!(self, i64, 
TimestampMicrosecondType)
+                instantiate_max_accumulator!(data_type, i64, 
TimestampMicrosecondType)
             }
             Timestamp(Nanosecond, _) => {
-                instantiate_max_accumulator!(self, i64, 
TimestampNanosecondType)
+                instantiate_max_accumulator!(data_type, i64, 
TimestampNanosecondType)
             }
             Decimal128(_, _) => {
-                instantiate_max_accumulator!(self, i128, Decimal128Type)
+                instantiate_max_accumulator!(data_type, i128, Decimal128Type)
             }
             Decimal256(_, _) => {
-                instantiate_max_accumulator!(self, i256, Decimal256Type)
+                instantiate_max_accumulator!(data_type, i256, Decimal256Type)
             }
 
             // It would be nice to have a fast implementation for Strings as 
well
             // https://github.com/apache/datafusion/issues/6906
 
             // This is only reached if groups_accumulator_supported is out of 
sync
-            _ => internal_err!(
-                "GroupsAccumulator not supported for max({})",
-                self.data_type
-            ),
+            _ => internal_err!("GroupsAccumulator not supported for max({})", 
data_type),
         }
     }
 
-    fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
-        Some(Arc::new(self.clone()))

Review Comment:
   You missed `reverse_udf` too.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to