jorgecarleitao commented on a change in pull request #8172:
URL: https://github.com/apache/arrow/pull/8172#discussion_r487309726
##########
File path: rust/datafusion/src/physical_plan/expressions.rs
##########
@@ -97,766 +104,712 @@ pub fn col(name: &str) -> Arc<dyn PhysicalExpr> {
/// SUM aggregate expression
#[derive(Debug)]
pub struct Sum {
+ name: String,
+ data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
+ nullable: bool,
}
-impl Sum {
- /// Create a new SUM aggregate function
- pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
- Self { expr }
+/// function return type of a sum
+pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
+ match arg_type {
+ DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
=> {
+ Ok(DataType::Int64)
+ }
+ DataType::UInt8 | DataType::UInt16 | DataType::UInt32 |
DataType::UInt64 => {
+ Ok(DataType::UInt64)
+ }
+ DataType::Float32 => Ok(DataType::Float32),
+ DataType::Float64 => Ok(DataType::Float64),
+ other => Err(ExecutionError::General(format!(
+ "SUM does not support type \"{:?}\"",
+ other
+ ))),
}
}
-impl AggregateExpr for Sum {
- fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
- match self.expr.data_type(input_schema)? {
- DataType::Int8 | DataType::Int16 | DataType::Int32 |
DataType::Int64 => {
- Ok(DataType::Int64)
- }
- DataType::UInt8 | DataType::UInt16 | DataType::UInt32 |
DataType::UInt64 => {
- Ok(DataType::UInt64)
- }
- DataType::Float32 => Ok(DataType::Float32),
- DataType::Float64 => Ok(DataType::Float64),
- other => Err(ExecutionError::General(format!(
- "SUM does not support {:?}",
- other
- ))),
+impl Sum {
+ /// Create a new SUM aggregate function
+ pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType)
-> Self {
+ Self {
+ name,
+ expr,
+ data_type,
+ nullable: true,
}
}
+}
- fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
- // null should be returned if no rows are aggregated
- Ok(true)
+impl AggregateExpr for Sum {
+ fn field(&self) -> Result<Field> {
+ Ok(Field::new(
+ &self.name,
+ self.data_type.clone(),
+ self.nullable,
+ ))
}
- fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- self.expr.evaluate(batch)
+ fn state_fields(&self) -> Result<Vec<Field>> {
+ Ok(vec![Field::new(
+ &format_state_name(&self.name, "sum"),
+ self.data_type.clone(),
+ self.nullable,
+ )])
}
- fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
- Rc::new(RefCell::new(SumAccumulator { sum: None }))
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![self.expr.clone()]
}
- fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
- Arc::new(Sum::new(Arc::new(Column::new(column_name))))
+ fn create_accumulator(&self) -> Result<Rc<RefCell<dyn Accumulator>>> {
+ Ok(Rc::new(RefCell::new(SumAccumulator::try_new(
+ &self.data_type,
+ )?)))
}
}
-macro_rules! sum_accumulate {
- ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident,
$TY:ty) => {{
- $SELF.sum = match $SELF.sum {
- Some(ScalarValue::$SCALAR_VARIANT(n)) => {
- Some(ScalarValue::$SCALAR_VARIANT(n + $VALUE as $TY))
- }
- Some(_) => {
- return Err(ExecutionError::InternalError(
- "Unexpected ScalarValue variant".to_string(),
- ))
- }
- None => Some(ScalarValue::$SCALAR_VARIANT($VALUE as $TY)),
- };
- }};
-}
-
#[derive(Debug)]
struct SumAccumulator {
- sum: Option<ScalarValue>,
+ sum: ScalarValue,
}
-impl Accumulator for SumAccumulator {
- fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
- if let Some(value) = value {
- match value {
- ScalarValue::Int8(value) => {
- sum_accumulate!(self, value, Int8Array, Int64, i64);
- }
- ScalarValue::Int16(value) => {
- sum_accumulate!(self, value, Int16Array, Int64, i64);
- }
- ScalarValue::Int32(value) => {
- sum_accumulate!(self, value, Int32Array, Int64, i64);
- }
- ScalarValue::Int64(value) => {
- sum_accumulate!(self, value, Int64Array, Int64, i64);
- }
- ScalarValue::UInt8(value) => {
- sum_accumulate!(self, value, UInt8Array, UInt64, u64);
- }
- ScalarValue::UInt16(value) => {
- sum_accumulate!(self, value, UInt16Array, UInt64, u64);
- }
- ScalarValue::UInt32(value) => {
- sum_accumulate!(self, value, UInt32Array, UInt64, u64);
- }
- ScalarValue::UInt64(value) => {
- sum_accumulate!(self, value, UInt64Array, UInt64, u64);
- }
- ScalarValue::Float32(value) => {
- sum_accumulate!(self, value, Float32Array, Float32, f32);
- }
- ScalarValue::Float64(value) => {
- sum_accumulate!(self, value, Float64Array, Float64, f64);
- }
- other => {
- return Err(ExecutionError::General(format!(
- "SUM does not support {:?}",
- other
- )))
- }
- }
- }
- Ok(())
+impl SumAccumulator {
+ /// new sum accumulator
+ pub fn try_new(data_type: &DataType) -> Result<Self> {
+ Ok(Self {
+ sum: ScalarValue::try_from(data_type)?,
+ })
}
+}
- fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
- let sum = match array.data_type() {
- DataType::UInt8 => {
- match
compute::sum(array.as_any().downcast_ref::<UInt8Array>().unwrap()) {
- Some(n) => Ok(Some(ScalarValue::UInt8(n))),
- None => Ok(None),
- }
+// returns the new value after sum with the new values, taking nullability into account
+macro_rules! typed_sum_accumulate {
+ ($OLD_VALUE:expr, $NEW_VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident,
$TYPE:ident) => {{
+ let array = $NEW_VALUES.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
+ let delta = compute::sum(array);
+ if $OLD_VALUE.is_none() {
+ ScalarValue::$SCALAR(delta.and_then(|e| Some(e as $TYPE)))
+ } else {
+ let delta = delta.and_then(|e| Some(e as $TYPE)).unwrap_or(0 as
$TYPE);
+ ScalarValue::from($OLD_VALUE.unwrap() + delta)
+ }
+ }};
+}
+
+// given an existing value `old` and an `array` of new values,
+// performs a sum, returning the new value.
+fn sum_accumulate(old: &ScalarValue, array: &ArrayRef) -> Result<ScalarValue> {
+ Ok(match old {
+ ScalarValue::Float64(sum) => match array.data_type() {
+ DataType::Float64 => {
+ typed_sum_accumulate!(sum, array, Float64Array, Float64, f64)
}
- DataType::UInt16 => {
- match
compute::sum(array.as_any().downcast_ref::<UInt16Array>().unwrap())
- {
- Some(n) => Ok(Some(ScalarValue::UInt16(n))),
- None => Ok(None),
- }
+ DataType::Float32 => {
+ typed_sum_accumulate!(sum, array, Float32Array, Float64, f64)
}
- DataType::UInt32 => {
- match
compute::sum(array.as_any().downcast_ref::<UInt32Array>().unwrap())
- {
- Some(n) => Ok(Some(ScalarValue::UInt32(n))),
- None => Ok(None),
- }
+ DataType::Int64 => {
+ typed_sum_accumulate!(sum, array, Int64Array, Float64, f64)
+ }
+ DataType::Int32 => {
+ typed_sum_accumulate!(sum, array, Int32Array, Float64, f64)
+ }
+ DataType::Int16 => {
+ typed_sum_accumulate!(sum, array, Int16Array, Float64, f64)
}
+ DataType::Int8 => typed_sum_accumulate!(sum, array, Int8Array,
Float64, f64),
DataType::UInt64 => {
- match
compute::sum(array.as_any().downcast_ref::<UInt64Array>().unwrap())
- {
- Some(n) => Ok(Some(ScalarValue::UInt64(n))),
- None => Ok(None),
- }
+ typed_sum_accumulate!(sum, array, UInt64Array, Float64, f64)
}
- DataType::Int8 => {
- match
compute::sum(array.as_any().downcast_ref::<Int8Array>().unwrap()) {
- Some(n) => Ok(Some(ScalarValue::Int8(n))),
- None => Ok(None),
- }
+ DataType::UInt32 => {
+ typed_sum_accumulate!(sum, array, UInt32Array, Float64, f64)
}
- DataType::Int16 => {
- match
compute::sum(array.as_any().downcast_ref::<Int16Array>().unwrap()) {
- Some(n) => Ok(Some(ScalarValue::Int16(n))),
- None => Ok(None),
- }
+ DataType::UInt16 => {
+ typed_sum_accumulate!(sum, array, UInt16Array, Float64, f64)
}
- DataType::Int32 => {
- match
compute::sum(array.as_any().downcast_ref::<Int32Array>().unwrap()) {
- Some(n) => Ok(Some(ScalarValue::Int32(n))),
- None => Ok(None),
- }
+ DataType::UInt8 => {
+ typed_sum_accumulate!(sum, array, UInt8Array, Float64, f64)
}
- DataType::Int64 => {
- match
compute::sum(array.as_any().downcast_ref::<Int64Array>().unwrap()) {
- Some(n) => Ok(Some(ScalarValue::Int64(n))),
- None => Ok(None),
- }
+ dt => {
+ return Err(ExecutionError::InternalError(format!(
+ "Sum f64 does not expect to receive type {:?}",
+ dt
+ )))
}
- DataType::Float32 => {
- match
compute::sum(array.as_any().downcast_ref::<Float32Array>().unwrap())
- {
- Some(n) => Ok(Some(ScalarValue::Float32(n))),
- None => Ok(None),
- }
+ },
+ ScalarValue::Float32(sum) => {
+ typed_sum_accumulate!(sum, array, Float32Array, Float32, f32)
+ }
+ ScalarValue::UInt64(sum) => match array.data_type() {
+ DataType::UInt64 => {
+ typed_sum_accumulate!(sum, array, UInt64Array, UInt64, u64)
}
- DataType::Float64 => {
- match
compute::sum(array.as_any().downcast_ref::<Float64Array>().unwrap())
- {
- Some(n) => Ok(Some(ScalarValue::Float64(n))),
- None => Ok(None),
- }
+ DataType::UInt32 => {
+ typed_sum_accumulate!(sum, array, UInt32Array, UInt64, u64)
}
- _ => Err(ExecutionError::ExecutionError(
- "Unsupported data type for SUM".to_string(),
- )),
- }?;
- self.accumulate_scalar(sum)
+ DataType::UInt16 => {
+ typed_sum_accumulate!(sum, array, UInt16Array, UInt64, u64)
+ }
+ DataType::UInt8 => typed_sum_accumulate!(sum, array, UInt8Array,
UInt64, u64),
+ dt => {
+ return Err(ExecutionError::InternalError(format!(
+ "Sum is not expected to receive type {:?}",
+ dt
+ )))
+ }
+ },
+ ScalarValue::Int64(sum) => match array.data_type() {
+ DataType::Int64 => typed_sum_accumulate!(sum, array, Int64Array,
Int64, i64),
+ DataType::Int32 => typed_sum_accumulate!(sum, array, Int32Array,
Int64, i64),
+ DataType::Int16 => typed_sum_accumulate!(sum, array, Int16Array,
Int64, i64),
+ DataType::Int8 => typed_sum_accumulate!(sum, array, Int8Array,
Int64, i64),
+ dt => {
+ return Err(ExecutionError::InternalError(format!(
+ "Sum is not expected to receive type {:?}",
+ dt
+ )))
+ }
+ },
+ e => {
+ return Err(ExecutionError::InternalError(format!(
+ "Sum is not expected to receive a scalar {:?}",
+ e
+ )))
+ }
+ })
+}
+
+impl Accumulator for SumAccumulator {
+ fn update(&mut self, values: &Vec<ArrayRef>) -> Result<()> {
+ // sum(v1, v2, v3) = v1 + v2 + v3
+ self.sum = sum_accumulate(&self.sum, &values[0])?;
+ Ok(())
}
- fn get_value(&self) -> Result<Option<ScalarValue>> {
- Ok(self.sum.clone())
+ fn merge(&mut self, states: &Vec<ArrayRef>) -> Result<()> {
+ let state = &states[0];
+ // sum(sum1, sum2, sum3) = sum1 + sum2 + sum3
+ self.sum = sum_accumulate(&self.sum, state)?;
+ Ok(())
}
-}
-/// Create a sum expression
-pub fn sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
- Arc::new(Sum::new(expr))
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![self.sum.clone()])
+ }
+
+ fn value(&self) -> Result<ScalarValue> {
+ Ok(self.sum.clone())
+ }
}
/// AVG aggregate expression
#[derive(Debug)]
pub struct Avg {
+ name: String,
+ data_type: DataType,
+ nullable: bool,
expr: Arc<dyn PhysicalExpr>,
}
-impl Avg {
- /// Create a new AVG aggregate function
- pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
- Self { expr }
+/// function return type of an average
+pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
+ match arg_type {
+ DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64
+ | DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64
+ | DataType::Float32
+ | DataType::Float64 => Ok(DataType::Float64),
+ other => Err(ExecutionError::General(format!(
+ "AVG does not support {:?}",
+ other
+ ))),
}
}
-impl AggregateExpr for Avg {
- fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
- match self.expr.data_type(input_schema)? {
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float32
- | DataType::Float64 => Ok(DataType::Float64),
- other => Err(ExecutionError::General(format!(
- "AVG does not support {:?}",
- other
- ))),
+impl Avg {
+ /// Create a new AVG aggregate function
+ pub fn new(expr: Arc<dyn PhysicalExpr>, name: String, data_type: DataType)
-> Self {
+ Self {
+ name,
+ expr,
+ data_type,
+ nullable: true,
}
}
+}
- fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
- // null should be returned if no rows are aggregated
- Ok(true)
+impl AggregateExpr for Avg {
+ fn field(&self) -> Result<Field> {
+ Ok(Field::new(&self.name, DataType::Float64, true))
}
- fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- self.expr.evaluate(batch)
+ fn state_fields(&self) -> Result<Vec<Field>> {
+ Ok(vec![
+ Field::new(
+ &format_state_name(&self.name, "count"),
+ DataType::UInt64,
+ true,
+ ),
+ Field::new(
+ &format_state_name(&self.name, "sum"),
+ DataType::Float64,
+ true,
+ ),
+ ])
}
- fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
- Rc::new(RefCell::new(AvgAccumulator {
- sum: None,
- count: None,
- }))
+ fn create_accumulator(&self) -> Result<Rc<RefCell<dyn Accumulator>>> {
+ Ok(Rc::new(RefCell::new(AvgAccumulator::try_new(
+ // avg is f64
+ &DataType::Float64,
+ )?)))
}
- fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
- Arc::new(Avg::new(Arc::new(Column::new(column_name))))
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![self.expr.clone()]
}
}
-macro_rules! avg_accumulate {
- ($SELF:ident, $VALUE:expr, $ARRAY_TYPE:ident) => {{
- match ($SELF.sum, $SELF.count) {
- (Some(sum), Some(count)) => {
- $SELF.sum = Some(sum + $VALUE as f64);
- $SELF.count = Some(count + 1);
- }
- _ => {
- $SELF.sum = Some($VALUE as f64);
- $SELF.count = Some(1);
- }
- };
- }};
-}
#[derive(Debug)]
struct AvgAccumulator {
- sum: Option<f64>,
- count: Option<i64>,
+ // sum is used for null
+ sum: ScalarValue,
+ count: u64,
+}
+
+impl AvgAccumulator {
+ pub fn try_new(datatype: &DataType) -> Result<Self> {
+ Ok(Self {
+ sum: ScalarValue::try_from(datatype)?,
+ count: 0,
+ })
+ }
}
impl Accumulator for AvgAccumulator {
- fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
- if let Some(value) = value {
- match value {
- ScalarValue::Int8(value) => avg_accumulate!(self, value,
Int8Array),
- ScalarValue::Int16(value) => avg_accumulate!(self, value,
Int16Array),
- ScalarValue::Int32(value) => avg_accumulate!(self, value,
Int32Array),
- ScalarValue::Int64(value) => avg_accumulate!(self, value,
Int64Array),
- ScalarValue::UInt8(value) => avg_accumulate!(self, value,
UInt8Array),
- ScalarValue::UInt16(value) => avg_accumulate!(self, value,
UInt16Array),
- ScalarValue::UInt32(value) => avg_accumulate!(self, value,
UInt32Array),
- ScalarValue::UInt64(value) => avg_accumulate!(self, value,
UInt64Array),
- ScalarValue::Float32(value) => avg_accumulate!(self, value,
Float32Array),
- ScalarValue::Float64(value) => avg_accumulate!(self, value,
Float64Array),
- other => {
- return Err(ExecutionError::General(format!(
- "AVG does not support {:?}",
- other
- )))
- }
- }
- }
+ fn update(&mut self, values: &Vec<ArrayRef>) -> Result<()> {
+ let values = &values[0];
+
+ self.count += (values.len() - values.data().null_count()) as u64;
+ self.sum = sum_accumulate(&self.sum, values)?;
Ok(())
}
- fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
- for row in 0..array.len() {
- self.accumulate_scalar(get_scalar_value(array, row)?)?;
- }
+ fn merge(&mut self, states: &Vec<ArrayRef>) -> Result<()> {
Review comment:
This is the prime example of what this PR enables: `merge` here receives two
state columns (count and sum) and uses them to update the accumulator's two
corresponding states.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]