jayzhan211 commented on code in PR #11013: URL: https://github.com/apache/datafusion/pull/11013#discussion_r1694066849
########## datafusion/functions-aggregate/src/min_max.rs: ########## @@ -123,170 +201,163 @@ macro_rules! instantiate_max_accumulator { /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType macro_rules! instantiate_min_accumulator { - ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{ + ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( - PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( - &$SELF.data_type, - |cur, new| { - if *cur > new { - *cur = new - } - }, - ) + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| { + if *cur > new { + *cur = new + } + }) // Initialize each accumulator to $NATIVE::MAX .with_starting_value($NATIVE::MAX), )) }}; } -impl AggregateExpr for Max { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { +impl AggregateUDFImpl for Max { + fn as_any(&self) -> &dyn std::any::Any { self } - fn field(&self) -> Result<Field> { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) + fn name(&self) -> &str { + "MAX" } - fn state_fields(&self) -> Result<Vec<Field>> { - Ok(vec![Field::new( - format_state_name(&self.name, "max"), - self.data_type.clone(), - true, - )]) + fn signature(&self) -> &Signature { + &self.signature } - fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { - vec![Arc::clone(&self.expr)] + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + type_coercion::aggregates::get_min_max_result_type(arg_types)? 
+ .into_iter() + .next() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Expected at one input type for MAX aggregate function" + )) + }) } - fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { - Ok(Box::new(MaxAccumulator::try_new(&self.data_type)?)) + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { + // let data_type = &min_max_aggregate_data_type(acc_args.data_type); + let data_type = acc_args.input_type; + Ok(Box::new(MaxAccumulator::try_new(data_type)?)) } - fn name(&self) -> &str { - &self.name + fn aliases(&self) -> &[String] { + &self.aliases } - fn groups_accumulator_supported(&self) -> bool { - use DataType::*; + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + // let data_type = min_max_aggregate_data_type(args.data_type); + let data_type = args.input_type; matches!( - self.data_type, - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float32 - | Float64 - | Decimal128(_, _) - | Decimal256(_, _) - | Date32 - | Date64 - | Time32(_) - | Time64(_) - | Timestamp(_, _) + data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Timestamp(_, _) ) } - fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> { + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result<Box<dyn GroupsAccumulator>> { use DataType::*; use TimeUnit::*; - - match self.data_type { - Int8 => instantiate_max_accumulator!(self, i8, Int8Type), - Int16 => instantiate_max_accumulator!(self, i16, Int16Type), - Int32 => instantiate_max_accumulator!(self, i32, Int32Type), - Int64 => instantiate_max_accumulator!(self, i64, 
Int64Type), - UInt8 => instantiate_max_accumulator!(self, u8, UInt8Type), - UInt16 => instantiate_max_accumulator!(self, u16, UInt16Type), - UInt32 => instantiate_max_accumulator!(self, u32, UInt32Type), - UInt64 => instantiate_max_accumulator!(self, u64, UInt64Type), +// let data_type = min_max_aggregate_data_type(args.data_type); + let data_type = args.input_type; + match data_type { + Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type), + Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type), + Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type), + Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type), + UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type), + UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type), + UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type), + UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type), Float32 => { - instantiate_max_accumulator!(self, f32, Float32Type) + instantiate_max_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_max_accumulator!(self, f64, Float64Type) + instantiate_max_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_max_accumulator!(self, i32, Date32Type), - Date64 => instantiate_max_accumulator!(self, i64, Date64Type), + Date32 => instantiate_max_accumulator!(data_type, i32, Date32Type), + Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_max_accumulator!(self, i32, Time32SecondType) + instantiate_max_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_max_accumulator!(self, i32, Time32MillisecondType) + instantiate_max_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - instantiate_max_accumulator!(self, i64, Time64MicrosecondType) + instantiate_max_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_max_accumulator!(self, 
i64, Time64NanosecondType) + instantiate_max_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - instantiate_max_accumulator!(self, i64, TimestampSecondType) + instantiate_max_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_max_accumulator!(self, i64, TimestampMillisecondType) + instantiate_max_accumulator!(data_type, i64, TimestampMillisecondType) } Timestamp(Microsecond, _) => { - instantiate_max_accumulator!(self, i64, TimestampMicrosecondType) + instantiate_max_accumulator!(data_type, i64, TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_max_accumulator!(self, i64, TimestampNanosecondType) + instantiate_max_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_max_accumulator!(self, i128, Decimal128Type) + instantiate_max_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_max_accumulator!(self, i256, Decimal256Type) + instantiate_max_accumulator!(data_type, i256, Decimal256Type) } // It would be nice to have a fast implementation for Strings as well // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync - _ => internal_err!( - "GroupsAccumulator not supported for max({})", - self.data_type - ), + _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), } } - fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> { - Some(Arc::new(self.clone())) Review Comment: You also missed `reverse_udf`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org