This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 230c68c02b Add `simplify` method to aggregate function (#10354)
230c68c02b is described below
commit 230c68c02bf0c3d5b7d50d24145eb50604420d4f
Author: Marko Milenković <[email protected]>
AuthorDate: Mon May 13 12:53:56 2024 +0100
Add `simplify` method to aggregate function (#10354)
* add simplify method for aggregate function
* simplify returns closure
---
.../examples/simplify_udaf_expression.rs | 180 +++++++++++++++++++++
datafusion/expr/src/function.rs | 13 ++
datafusion/expr/src/udaf.rs | 33 +++-
.../src/simplify_expressions/expr_simplifier.rs | 105 +++++++++++-
4 files changed, 328 insertions(+), 3 deletions(-)
diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs
b/datafusion-examples/examples/simplify_udaf_expression.rs
new file mode 100644
index 0000000000..92deb20272
--- /dev/null
+++ b/datafusion-examples/examples/simplify_udaf_expression.rs
@@ -0,0 +1,180 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_schema::{Field, Schema};
+use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
+use datafusion_expr::function::AggregateFunctionSimplification;
+use datafusion_expr::simplify::SimplifyInfo;
+
+use std::{any::Any, sync::Arc};
+
+use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
+use datafusion::error::Result;
+use datafusion::{assert_batches_eq, prelude::*};
+use datafusion_common::cast::as_float64_array;
+use datafusion_expr::{
+ expr::{AggregateFunction, AggregateFunctionDefinition},
+ function::AccumulatorArgs,
+ Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
+};
+
+/// This example shows how to use the AggregateUDFImpl::simplify API to
simplify/replace user
+/// defined aggregate function with a different expression which is defined in
the `simplify` method.
+
+#[derive(Debug, Clone)]
+struct BetterAvgUdaf {
+ signature: Signature,
+}
+
+impl BetterAvgUdaf {
+    /// Create a new instance of the `BetterAvgUdaf` struct.
+    ///
+    /// The signature accepts exactly one `Float64` argument and is
+    /// declared [`Volatility::Immutable`].
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for BetterAvgUdaf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "better_avg"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    // The execution hooks below are expected to be unreachable in this
+    // example: `simplify` replaces every `better_avg` call with the built-in
+    // `AVG` aggregate during expression simplification.
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        unimplemented!("should not be invoked")
+    }
+
+    fn state_fields(
+        &self,
+        _name: &str,
+        _value_type: DataType,
+        _ordering_fields: Vec<Field>,
+    ) -> Result<Vec<Field>> {
+        unimplemented!("should not be invoked")
+    }
+
+    fn groups_accumulator_supported(&self) -> bool {
+        true
+    }
+
+    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+        unimplemented!("should not get here");
+    }
+
+    // Override `simplify` to return a closure that substitutes the user
+    // defined function call with a new expression.
+    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
+        // To illustrate the functionality, replace the UDF call with the
+        // built-in `AVG` aggregate function. Every other property of the
+        // original call (arguments, DISTINCT, FILTER, ORDER BY and null
+        // treatment) is carried over unchanged.
+        let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| {
+            Ok(Expr::AggregateFunction(AggregateFunction {
+                func_def: AggregateFunctionDefinition::BuiltIn(
+                    // yes it is the same Avg, `BetterAvgUdaf` was just a
+                    // marketing pitch :)
+                    datafusion_expr::aggregate_function::AggregateFunction::Avg,
+                ),
+                args: aggregate_function.args,
+                distinct: aggregate_function.distinct,
+                filter: aggregate_function.filter,
+                order_by: aggregate_function.order_by,
+                null_treatment: aggregate_function.null_treatment,
+            }))
+        };
+
+        Some(Box::new(simplify))
+    }
+}
+
+/// Creates a [`SessionContext`] with an in-memory table `t` holding two
+/// non-nullable `Float32` columns, `a` and `b`, split across two partitions.
+fn create_context() -> Result<SessionContext> {
+    use datafusion::datasource::MemTable;
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Float32, false),
+        Field::new("b", DataType::Float32, false),
+    ]));
+
+    // Builds one record batch for `schema` from the values of `a` and `b`.
+    let make_batch = |a: Vec<f32>, b: Vec<f32>| {
+        RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(Float32Array::from(a)),
+                Arc::new(Float32Array::from(b)),
+            ],
+        )
+    };
+
+    // Two partitions: three rows in the first, one row in the second.
+    let partition1 = make_batch(vec![2.0, 4.0, 8.0], vec![2.0, 2.0, 2.0])?;
+    let partition2 = make_batch(vec![16.0], vec![2.0])?;
+
+    // In-memory table; the Spark API equivalent is `createDataFrame(...)`.
+    let provider = MemTable::try_new(schema, vec![vec![partition1], vec![partition2]])?;
+
+    let ctx = SessionContext::new();
+    ctx.register_table("t", Arc::new(provider))?;
+    Ok(ctx)
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context()?;
+
+    let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
+    ctx.register_udaf(better_avg.clone());
+
+    // Query through SQL; `better_avg` is rewritten to the built-in average.
+    let result = ctx
+        .sql("SELECT better_avg(a) FROM t group by b")
+        .await?
+        .collect()
+        .await?;
+
+    // Every row has `b = 2.0`, so there is a single group whose average of
+    // `a` is (2 + 4 + 8 + 16) / 4 = 7.5. The pretty-printed table pads each
+    // cell to the column width, and `assert_batches_eq!` compares lines
+    // exactly, so the value row must carry the same padding.
+    let expected = [
+        "+-----------------+",
+        "| better_avg(t.a) |",
+        "+-----------------+",
+        "| 7.5             |",
+        "+-----------------+",
+    ];
+
+    assert_batches_eq!(expected, &result);
+
+    // The same computation through the DataFrame API.
+    let df = ctx.table("t").await?;
+    let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;
+
+    let results = df.collect().await?;
+    let result = as_float64_array(results[0].column(0))?;
+
+    assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
+    println!("The average of [2,4,8,16] is {}", result.value(0));
+
+    Ok(())
+}
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 7a92a50ae1..4e4d77924a 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -97,3 +97,16 @@ pub type PartitionEvaluatorFactory =
/// its state, given its return datatype.
pub type StateTypeFunction =
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
+
+/// Simplifier closure returned by [`crate::udaf::AggregateUDFImpl::simplify`].
+///
+/// The closure takes two arguments:
+/// * `aggregate_function`: the [`crate::expr::AggregateFunction`] call for
+///   which `simplify` has been invoked, passed by value so it can be consumed
+///   to build the replacement expression
+/// * `info`: the [`crate::simplify::SimplifyInfo`] supplied by the caller
+///
+/// The closure returns the simplified [`Expr`] (which replaces the original
+/// call) or an error.
+pub type AggregateFunctionSimplification = Box<
+    dyn Fn(
+        crate::expr::AggregateFunction,
+        &dyn crate::simplify::SimplifyInfo,
+    ) -> Result<Expr>,
+>;
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index e5a47ddcd8..95121d78e7 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -17,7 +17,7 @@
//! [`AggregateUDF`]: User Defined Aggregate Functions
-use crate::function::AccumulatorArgs;
+use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
@@ -199,6 +199,12 @@ impl AggregateUDF {
pub fn coerce_types(&self, _args: &[DataType]) -> Result<Vec<DataType>> {
    let name = self.name();
    not_impl_err!("coerce_types not implemented for {:?} yet", name)
}
+
+    /// Returns this function's simplification closure, if it defines one.
+    ///
+    /// Delegates to the wrapped implementation.
+    /// See [`AggregateUDFImpl::simplify`] for more details.
+    pub fn simplify(&self) -> Option<AggregateFunctionSimplification> {
+        self.inner.simplify()
+    }
}
impl<F> From<F> for AggregateUDF
@@ -358,6 +364,31 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn aliases(&self) -> &[String] {
    // By default a UDAF has no alternative names.
    const NO_ALIASES: &[String] = &[];
    NO_ALIASES
}
+
+    /// Optionally apply per-UDAF simplification / rewrite rules.
+    ///
+    /// This can be used to apply function specific simplification rules
+    /// during optimization, for example replacing a call to a user defined
+    /// aggregate with an equivalent built-in aggregate. The default
+    /// implementation does nothing (returns `None`).
+    ///
+    /// Note that DataFusion handles simplifying the call's arguments
+    /// (including constant folding inside them). Thus, there is no need to
+    /// implement such optimizations manually for specific UDAFs.
+    ///
+    /// # Returns
+    ///
+    /// `None` if simplify is not defined, or a closure with two arguments:
+    /// * `aggregate_function`: the [`crate::expr::AggregateFunction`] call for
+    ///   which `simplify` has been invoked
+    /// * `info`: [`crate::simplify::SimplifyInfo`]
+    ///
+    /// The closure returns the simplified [`Expr`] or an error.
+    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
+        None
+    }
}
/// AggregateUDF that adds an alias to the underlying function. It is better to
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 5122de4f09..55052542a8 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -32,7 +32,7 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
};
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result,
ScalarValue};
-use datafusion_expr::expr::{InList, InSubquery};
+use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator,
Volatility,
@@ -1382,6 +1382,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for
Simplifier<'a, S> {
}
}
+ Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
+ func_def: AggregateFunctionDefinition::UDF(ref udaf),
+ ..
+ }) => match (udaf.simplify(), expr) {
+ (Some(simplify_function), Expr::AggregateFunction(af)) => {
+ Transformed::yes(simplify_function(af, info)?)
+ }
+ (_, expr) => Transformed::no(expr),
+ },
+
//
// Rules for Between
//
@@ -1748,7 +1758,9 @@ fn inlist_except(mut l1: InList, l2: InList) ->
Result<Expr> {
#[cfg(test)]
mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
- use datafusion_expr::{interval_arithmetic::Interval, *};
+ use datafusion_expr::{
+ function::AggregateFunctionSimplification,
interval_arithmetic::Interval, *,
+ };
use std::{
collections::HashMap,
ops::{BitAnd, BitOr, BitXor},
@@ -3698,4 +3710,93 @@ mod tests {
assert_eq!(expr, expected);
assert_eq!(num_iter, 2);
}
+
+    #[test]
+    fn test_simplify_udaf() {
+        // Wraps a mock UDAF in an argument-less aggregate function call.
+        let udaf_call = |mock: SimplifyMockUdaf| {
+            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
+                AggregateUDF::new_from_impl(mock).into(),
+                vec![],
+                false,
+                None,
+                None,
+                None,
+            ))
+        };
+
+        // A UDAF that provides `simplify` is replaced by the closure's result.
+        assert_eq!(
+            simplify(udaf_call(SimplifyMockUdaf::new_with_simplify())),
+            col("result_column")
+        );
+
+        // A UDAF without `simplify` is left untouched.
+        let untouched = udaf_call(SimplifyMockUdaf::new_without_simplify());
+        assert_eq!(simplify(untouched.clone()), untouched);
+    }
+
+    /// Mock UDAF used by the UDAF simplification tests.
+    ///
+    /// The `simplify` flag controls whether the mock's
+    /// `AggregateUDFImpl::simplify` hook yields a rewrite closure or `None`.
+    #[derive(Debug, Clone)]
+    struct SimplifyMockUdaf {
+        simplify: bool,
+    }
+
+    impl SimplifyMockUdaf {
+        /// Mock whose `simplify` hook returns a new expression.
+        fn new_with_simplify() -> Self {
+            Self::with_flag(true)
+        }
+
+        /// Mock whose `simplify` hook leaves the call unchanged.
+        fn new_without_simplify() -> Self {
+            Self::with_flag(false)
+        }
+
+        /// Shared constructor behind the two named variants.
+        fn with_flag(simplify: bool) -> Self {
+            Self { simplify }
+        }
+    }
+
+    impl AggregateUDFImpl for SimplifyMockUdaf {
+        fn as_any(&self) -> &dyn std::any::Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "mock_simplify"
+        }
+
+        // The hooks below are not exercised by the simplification tests; they
+        // panic with one consistent message if that assumption ever breaks.
+        fn signature(&self) -> &Signature {
+            unimplemented!("not needed for tests")
+        }
+
+        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+            unimplemented!("not needed for tests")
+        }
+
+        fn accumulator(
+            &self,
+            _acc_args: function::AccumulatorArgs,
+        ) -> Result<Box<dyn Accumulator>> {
+            unimplemented!("not needed for tests")
+        }
+
+        fn groups_accumulator_supported(&self) -> bool {
+            unimplemented!("not needed for tests")
+        }
+
+        fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+            unimplemented!("not needed for tests")
+        }
+
+        /// Returns a closure that replaces the call with `col("result_column")`
+        /// when `self.simplify` is `true`, otherwise `None`.
+        fn simplify(&self) -> Option<AggregateFunctionSimplification> {
+            if self.simplify {
+                Some(Box::new(|_, _| Ok(col("result_column"))))
+            } else {
+                None
+            }
+        }
+    }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]