alamb commented on a change in pull request #1525: URL: https://github.com/apache/arrow-datafusion/pull/1525#discussion_r780516560
########## File path: datafusion/src/scalar.rs ########## @@ -526,6 +526,282 @@ macro_rules! eq_array_primitive { } impl ScalarValue { + /// Return true if the value is numeric + pub fn is_numeric(&self) -> bool { + matches!(self, + ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + ) + } + + /// Add two numeric ScalarValues + pub fn add(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(DataFusionError::Internal(format!( + "Addition only supports numeric types, \ + here has {:?} and {:?}", + lhs.get_datatype(), + rhs.get_datatype() + ))); + } + + // TODO: Finding a good way to support operation between different types without Review comment: I think we could use the expression evaluator here rather than writing a giant match statement. Something like the following ``` /// Multiply two numeric ScalarValues pub fn mul(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> { let props = ExecutionProps::new(); let const_evaluator = ConstEvaluator::new(&props); // Create a LogicalPlan::Expr representing lhs * rhs let e = Expr::Literal(lhs.clone()) * Expr::Literal(rhs.clone()); const_evaluator.evaluate_to_scalar(e) } ``` Benefits: 1. Requires much less code 2. Has built-in support for `Decimal` and other types Drawbacks 1. It might be slower at runtime without some additional special-case handling of ScalarValues, but we can do that as a follow-on PR ########## File path: datafusion/src/scalar.rs ########## @@ -526,6 +526,282 @@ macro_rules! 
eq_array_primitive { } impl ScalarValue { + /// Return true if the value is numeric + pub fn is_numeric(&self) -> bool { + matches!(self, + ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + ) + } + + /// Add two numeric ScalarValues + pub fn add(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> { Review comment: I think this code is redundant with ``` pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> { ``` https://github.com/apache/arrow-datafusion/blob/9d3186693b614db57143adbd81c82a60752a8bac/datafusion/src/physical_plan/expressions/sum.rs#L266-L349 ########## File path: datafusion/src/scalar.rs ########## @@ -3081,4 +3357,209 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned())) ); } + + macro_rules! test_scalar_op { + ($OP:ident, $LHS:expr, $LHS_TYPE:ident, $RHS:expr, $RHS_TYPE:ident, $RESULT:expr, $RESULT_TYPE:ident) => {{ + let v1 = &ScalarValue::from($LHS as $LHS_TYPE); + let v2 = &ScalarValue::from($RHS as $RHS_TYPE); + assert_eq!( + ScalarValue::$OP(v1, v2).unwrap(), + ScalarValue::from($RESULT as $RESULT_TYPE) + ); + }}; + } + + macro_rules! test_scalar_op_err { + ($OP:ident, $LHS:expr, $LHS_TYPE:ident, $RHS:expr, $RHS_TYPE:ident) => {{ + let v1 = &ScalarValue::from($LHS as $LHS_TYPE); + let v2 = &ScalarValue::from($RHS as $RHS_TYPE); + let actual = ScalarValue::add(v1, v2).is_err(); Review comment: I believe this should be referring to `$OP` rather than `add`: ```suggestion let actual = ScalarValue::$OP(v1, v2).is_err(); ``` ########## File path: datafusion/src/scalar.rs ########## @@ -526,6 +526,282 @@ macro_rules! 
eq_array_primitive { } impl ScalarValue { + /// Return true if the value is numeric + pub fn is_numeric(&self) -> bool { + matches!(self, + ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + ) + } + + /// Add two numeric ScalarValues + pub fn add(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(DataFusionError::Internal(format!( + "Addition only supports numeric types, \ + here has {:?} and {:?}", + lhs.get_datatype(), + rhs.get_datatype() + ))); + } + + // TODO: Finding a good way to support operation between different types without + // writing a hige match block. + // TODO: Add support for decimal types + match (lhs, rhs) { + (ScalarValue::Decimal128(_, _, _), _) | + (_, ScalarValue::Decimal128(_, _, _)) => { + Err(DataFusionError::Internal( + "Addition with Decimals are not supported for now".to_string() + )) + }, + // f64 / _ + (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() + f2.unwrap()))) Review comment: These `unwrap`s will panic if either argument is `None` (aka represents NULL). ########## File path: datafusion/src/physical_plan/expressions/variance.rs ########## @@ -0,0 +1,376 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use crate::scalar::ScalarValue; +use arrow::datatypes::DataType; +use arrow::datatypes::Field; + +use super::format_state_name; + +/// VARIANCE aggregate expression +#[derive(Debug)] +pub struct Variance { + name: String, + expr: Arc<dyn PhysicalExpr>, +} + +/// function return type of variance +pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> { + match arg_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(DataFusionError::Plan(format!( + "VARIANCE does not support {:?}", + other + ))), + } +} + +pub(crate) fn is_variance_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +impl Variance { + /// Create a new VARIANCE aggregate function + pub fn new( + expr: Arc<dyn PhysicalExpr>, + name: impl Into<String>, + data_type: DataType, + ) -> Self { + // the result of variance just support FLOAT64 data type. 
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + } + } +} + +impl AggregateExpr for Variance { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result<Field> { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Ok(Box::new(VarianceAccumulator::try_new()?)) + } + + fn state_fields(&self) -> Result<Vec<Field>> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +/// An accumulator to compute variance Review comment: I suggest bringing the algorithmic reference from the PR description into the code ```suggestion /// An accumulator to compute variance /// /// The algorithm used is an online implementation and numerically stable. It is based on this paper: /// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". /// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. /// /// It has been analyzed here: /// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". /// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. 
``` ########## File path: datafusion/src/physical_plan/expressions/mod.rs ########## @@ -84,9 +86,13 @@ pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; +pub(crate) use stddev::is_stddev_support_arg_type; +pub use stddev::{stddev_return_type, Stddev}; pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; +pub(crate) use variance::is_variance_support_arg_type; +pub use variance::{variance_return_type, Variance}; Review comment: ```suggestion pub (crate) use variance::{variance_return_type, Variance}; ``` ########## File path: datafusion/src/physical_plan/expressions/stddev.rs ########## @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{ + expressions::variance::VarianceAccumulator, Accumulator, AggregateExpr, PhysicalExpr, +}; +use crate::scalar::ScalarValue; +use arrow::datatypes::DataType; +use arrow::datatypes::Field; + +use super::format_state_name; + +/// STDDEV (standard deviation) aggregate expression +#[derive(Debug)] +pub struct Stddev { + name: String, + expr: Arc<dyn PhysicalExpr>, + data_type: DataType, +} + +/// function return type of standard deviation +pub fn stddev_return_type(arg_type: &DataType) -> Result<DataType> { + match arg_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(DataFusionError::Plan(format!( + "STDDEV does not support {:?}", + other + ))), + } +} + +pub(crate) fn is_stddev_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +impl Stddev { + /// Create a new STDDEV aggregate function + pub fn new( + expr: Arc<dyn PhysicalExpr>, + name: impl Into<String>, + data_type: DataType, + ) -> Self { + // the result of stddev just support FLOAT64 and Decimal data type. 
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + data_type, + } + } +} + +impl AggregateExpr for Stddev { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result<Field> { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Ok(Box::new(StddevAccumulator::try_new()?)) + } + + fn state_fields(&self) -> Result<Vec<Field>> { + Ok(vec![ Review comment: this is wonderful ########## File path: datafusion/src/physical_plan/expressions/mod.rs ########## @@ -84,9 +86,13 @@ pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; +pub(crate) use stddev::is_stddev_support_arg_type; +pub use stddev::{stddev_return_type, Stddev}; Review comment: Put another way, perhaps this could be ```suggestion pub (crate) use stddev::{stddev_return_type, Stddev}; ``` ########## File path: datafusion/src/physical_plan/expressions/variance.rs ########## @@ -0,0 +1,376 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use crate::scalar::ScalarValue; +use arrow::datatypes::DataType; +use arrow::datatypes::Field; + +use super::format_state_name; + +/// VARIANCE aggregate expression +#[derive(Debug)] +pub struct Variance { + name: String, + expr: Arc<dyn PhysicalExpr>, +} + +/// function return type of variance +pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> { + match arg_type { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(DataFusionError::Plan(format!( + "VARIANCE does not support {:?}", + other + ))), + } +} + +pub(crate) fn is_variance_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +impl Variance { + /// Create a new VARIANCE aggregate function + pub fn new( + expr: Arc<dyn PhysicalExpr>, + name: impl Into<String>, + data_type: DataType, + ) -> Self { + // the result of variance just support FLOAT64 data type. 
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + } + } +} + +impl AggregateExpr for Variance { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result<Field> { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Ok(Box::new(VarianceAccumulator::try_new()?)) + } + + fn state_fields(&self) -> Result<Vec<Field>> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +/// An accumulator to compute variance +#[derive(Debug)] +pub struct VarianceAccumulator { + m2: ScalarValue, + mean: ScalarValue, + count: u64, +} + +impl VarianceAccumulator { + /// Creates a new `VarianceAccumulator` + pub fn try_new() -> Result<Self> { + Ok(Self { + m2: ScalarValue::from(0 as f64), + mean: ScalarValue::from(0 as f64), + count: 0, + }) + } + + pub fn get_count(&self) -> u64 { + self.count + } + + pub fn get_mean(&self) -> ScalarValue { + self.mean.clone() + } + + pub fn get_m2(&self) -> ScalarValue { + self.m2.clone() + } +} + +impl Accumulator for VarianceAccumulator { + fn state(&self) -> Result<Vec<ScalarValue>> { + Ok(vec![ + ScalarValue::from(self.count), + self.mean.clone(), + self.m2.clone(), + ]) + } + + fn update(&mut self, values: &[ScalarValue]) -> Result<()> { + let values = &values[0]; + let is_empty = values.is_null(); + + if !is_empty { + let new_count = self.count + 1; + let delta1 = ScalarValue::add(values, &self.mean.arithmetic_negate())?; Review comment: Using `ScalarValue`s like this to accumulate each 
operation is likely to be quite slow at runtime. However, I think it would be fine to put in as an initial implementation and then implement an optimized version using `update_batch` and arrow compute kernels as a follow-on PR ```rust /// updates the accumulator's state from a vector of arrays. fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org