jorgecarleitao commented on a change in pull request #8172:
URL: https://github.com/apache/arrow/pull/8172#discussion_r487308486
##########
File path: rust/datafusion/src/physical_plan/expressions.rs
##########
@@ -97,766 +104,712 @@ pub fn col(name: &str) -> Arc<dyn PhysicalExpr> {
/// SUM aggregate expression
#[derive(Debug)]
pub struct Sum {
+ name: String,
+ data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
+ nullable: bool,
}
-impl Sum {
- /// Create a new SUM aggregate function
- pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
- Self { expr }
+/// function return type of a sum
+pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
+ match arg_type {
+ DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
=> {
+ Ok(DataType::Int64)
+ }
+ DataType::UInt8 | DataType::UInt16 | DataType::UInt32 |
DataType::UInt64 => {
+ Ok(DataType::UInt64)
+ }
+ DataType::Float32 => Ok(DataType::Float32),
+ DataType::Float64 => Ok(DataType::Float64),
+ other => Err(ExecutionError::General(format!(
+ "SUM does not support type \"{:?}\"",
+ other
+ ))),
}
}
-impl AggregateExpr for Sum {
- fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
- match self.expr.data_type(input_schema)? {
- DataType::Int8 | DataType::Int16 | DataType::Int32 |
DataType::Int64 => {
- Ok(DataType::Int64)
- }
- DataType::UInt8 | DataType::UInt16 | DataType::UInt32 |
DataType::UInt64 => {
- Ok(DataType::UInt64)
- }
- DataType::Float32 => Ok(DataType::Float32),
- DataType::Float64 => Ok(DataType::Float64),
- other => Err(ExecutionError::General(format!(
- "SUM does not support {:?}",
- other
- ))),
+impl Sum {
+ /// Create a new SUM aggregate function
+ pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType)
-> Self {
+ Self {
+ name,
+ expr,
+ data_type,
+ nullable: true,
}
}
+}
- fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
- // null should be returned if no rows are aggregated
- Ok(true)
+impl AggregateExpr for Sum {
+ fn field(&self) -> Result<Field> {
+ Ok(Field::new(
+ &self.name,
+ self.data_type.clone(),
+ self.nullable,
+ ))
}
- fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- self.expr.evaluate(batch)
+ fn state_fields(&self) -> Result<Vec<Field>> {
+ Ok(vec![Field::new(
+ &format_state_name(&self.name, "sum"),
+ self.data_type.clone(),
+ self.nullable,
+ )])
}
- fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
- Rc::new(RefCell::new(SumAccumulator { sum: None }))
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![self.expr.clone()]
}
- fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
- Arc::new(Sum::new(Arc::new(Column::new(column_name))))
+ fn create_accumulator(&self) -> Result<Rc<RefCell<dyn Accumulator>>> {
+ Ok(Rc::new(RefCell::new(SumAccumulator::try_new(
+ &self.data_type,
+ )?)))
}
}
-macro_rules! sum_accumulate {
Review comment:
The removed `sum_accumulate` macro operated on a row-by-row basis; it has been
replaced by a batch operation using `compute::sum`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org