This is an automated email from the ASF dual-hosted git repository.
xudong963 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 933fec845e Consolidate example: simplify_udaf_expression.rs into
advanced_udaf.rs (#13905)
933fec845e is described below
commit 933fec845e02bb8983a0b932cfa12ebb27054748
Author: Takahiro Ebato <[email protected]>
AuthorDate: Fri Dec 27 15:32:56 2024 +0900
Consolidate example: simplify_udaf_expression.rs into advanced_udaf.rs
(#13905)
---
datafusion-examples/examples/advanced_udaf.rs | 185 +++++++++++++++------
.../examples/simplify_udaf_expression.rs | 176 --------------------
2 files changed, 132 insertions(+), 229 deletions(-)
diff --git a/datafusion-examples/examples/advanced_udaf.rs
b/datafusion-examples/examples/advanced_udaf.rs
index 414596bdc6..a914cea4a9 100644
--- a/datafusion-examples/examples/advanced_udaf.rs
+++ b/datafusion-examples/examples/advanced_udaf.rs
@@ -31,7 +31,9 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
- function::{AccumulatorArgs, StateFieldsArgs},
+ expr::AggregateFunction,
+ function::{AccumulatorArgs, AggregateFunctionSimplification,
StateFieldsArgs},
+ simplify::SimplifyInfo,
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};
@@ -197,40 +199,6 @@ impl Accumulator for GeometricMean {
}
}
-// create local session context with an in-memory table
-fn create_context() -> Result<SessionContext> {
- use datafusion::datasource::MemTable;
- // define a schema.
- let schema = Arc::new(Schema::new(vec![
- Field::new("a", DataType::Float32, false),
- Field::new("b", DataType::Float32, false),
- ]));
-
- // define data in two partitions
- let batch1 = RecordBatch::try_new(
- schema.clone(),
- vec![
- Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
- Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
- ],
- )?;
- let batch2 = RecordBatch::try_new(
- schema.clone(),
- vec![
- Arc::new(Float32Array::from(vec![64.0])),
- Arc::new(Float32Array::from(vec![2.0])),
- ],
- )?;
-
- // declare a new context. In spark API, this corresponds to a new spark
SQLsession
- let ctx = SessionContext::new();
-
- // declare a table in memory. In spark API, this corresponds to
createDataFrame(...).
- let provider = MemTable::try_new(schema, vec![vec![batch1],
vec![batch2]])?;
- ctx.register_table("t", Arc::new(provider))?;
- Ok(ctx)
-}
-
// Define a `GroupsAccumulator` for GeometricMean
/// which handles accumulator state for multiple groups at once.
/// This API is significantly more complicated than `Accumulator`, which
manages
@@ -399,35 +367,146 @@ impl GroupsAccumulator for
GeometricMeanGroupsAccumulator {
}
}
+/// This example shows how to use the `AggregateUDFImpl::simplify` API to
simplify/replace a user
+/// defined aggregate function with a different expression, which is defined in
the `simplify` method; the accumulator methods are intentionally left unimplemented.
+#[derive(Debug, Clone)]
+struct SimplifiedGeoMeanUdaf {
+ signature: Signature,
+}
+
+impl SimplifiedGeoMeanUdaf {
+ fn new() -> Self {
+ Self {
+ signature: Signature::exact(vec![DataType::Float64],
Volatility::Immutable),
+ }
+ }
+}
+
+impl AggregateUDFImpl for SimplifiedGeoMeanUdaf {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "simplified_geo_mean"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ Ok(DataType::Float64)
+ }
+
+ fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
+ unimplemented!("should not be invoked")
+ }
+
+ fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
+ unimplemented!("should not be invoked")
+ }
+
+ fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
+ true
+ }
+
+ fn create_groups_accumulator(
+ &self,
+ _args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
+ unimplemented!("should not get here");
+ }
+
+ /// Always returns a rewrite that replaces this UDAF with a call to `GeoMeanUdaf` during query
optimization, forwarding args, distinct, filter, order_by and null_treatment unchanged.
+ fn simplify(&self) -> Option<AggregateFunctionSimplification> {
+ let simplify = |aggregate_function: AggregateFunction, _: &dyn
SimplifyInfo| {
+ // Replace the UDAF with `GeoMeanUdaf` as a placeholder example
to demonstrate the `simplify` method.
+ // In real-world scenarios, you might instead rewrite the call into built-in
aggregate functions or expressions.
+ Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
+ Arc::new(AggregateUDF::from(GeoMeanUdaf::new())),
+ aggregate_function.args,
+ aggregate_function.distinct,
+ aggregate_function.filter,
+ aggregate_function.order_by,
+ aggregate_function.null_treatment,
+ )))
+ };
+ Some(Box::new(simplify))
+ }
+}
+
+// create a local session context with an in-memory table registered as `t`
+fn create_context() -> Result<SessionContext> {
+ use datafusion::datasource::MemTable;
+ // define a schema with two non-nullable Float32 columns, `a` and `b`.
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Float32, false),
+ Field::new("b", DataType::Float32, false),
+ ]));
+
+ // define data in two partitions (column `a` holds {2,4,8,64}; column `b` is always 2.0)
+ let batch1 = RecordBatch::try_new(
+ schema.clone(),
+ vec![
+ Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
+ Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
+ ],
+ )?;
+ let batch2 = RecordBatch::try_new(
+ schema.clone(),
+ vec![
+ Arc::new(Float32Array::from(vec![64.0])),
+ Arc::new(Float32Array::from(vec![2.0])),
+ ],
+ )?;
+
+ // declare a new context. In the Spark API, this corresponds to a new Spark
SQL session
+ let ctx = SessionContext::new();
+
+ // declare a table in memory. In the Spark API, this corresponds to
createDataFrame(...).
+ let provider = MemTable::try_new(schema, vec![vec![batch1],
vec![batch2]])?;
+ ctx.register_table("t", Arc::new(provider))?;
+ Ok(ctx)
+}
+
#[tokio::main]
async fn main() -> Result<()> {
let ctx = create_context()?;
- // create the AggregateUDF
- let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new());
- ctx.register_udaf(geometric_mean.clone());
+ let geo_mean_udf = AggregateUDF::from(GeoMeanUdaf::new());
+ let simplified_geo_mean_udf =
AggregateUDF::from(SimplifiedGeoMeanUdaf::new());
+
+ for (udf, udf_name) in [
+ (geo_mean_udf, "geo_mean"),
+ (simplified_geo_mean_udf, "simplified_geo_mean"),
+ ] {
+ ctx.register_udaf(udf.clone());
- let sql_df = ctx.sql("SELECT geo_mean(a) FROM t group by b").await?;
- sql_df.show().await?;
+ let sql_df = ctx
+ .sql(&format!("SELECT {}(a) FROM t GROUP BY b", udf_name))
+ .await?;
+ sql_df.show().await?;
- // get a DataFrame from the context
- // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric
mean is 8.0.
- let df = ctx.table("t").await?;
+ // get a DataFrame for table `t` from the context
+ // column `a` of this table is f32 with values {2,4,8,64}, whose
geometric mean is 8.0.
+ let df = ctx.table("t").await?;
- // perform the aggregation
- let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?;
+ // perform the aggregation over all rows (no GROUP BY)
+ let df = df.aggregate(vec![], vec![udf.call(vec![col("a")])])?;
- // note that "a" is f32, not f64. DataFusion coerces it to match the
UDAF's signature.
+ // note that "a" is f32, not f64. DataFusion coerces it to match the
UDAF's signature.
- // execute the query
- let results = df.collect().await?;
+ // execute the query
+ let results = df.collect().await?;
- // downcast the array to the expected type
- let result = as_float64_array(results[0].column(0))?;
+ // downcast the array to the expected type
+ let result = as_float64_array(results[0].column(0))?;
- // verify that the calculation is correct
- assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
- println!("The geometric mean of [2,4,8,64] is {}", result.value(0));
+ // verify that the calculation is correct
+ assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
+ println!("The geometric mean of [2,4,8,64] is {}", result.value(0));
+ }
Ok(())
}
diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs
b/datafusion-examples/examples/simplify_udaf_expression.rs
deleted file mode 100644
index 52a27317e3..0000000000
--- a/datafusion-examples/examples/simplify_udaf_expression.rs
+++ /dev/null
@@ -1,176 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use std::{any::Any, sync::Arc};
-
-use arrow_schema::{Field, Schema};
-
-use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
-use datafusion::error::Result;
-use datafusion::functions_aggregate::average::avg_udaf;
-use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
-use datafusion::{assert_batches_eq, prelude::*};
-use datafusion_common::cast::as_float64_array;
-use datafusion_expr::function::{AggregateFunctionSimplification,
StateFieldsArgs};
-use datafusion_expr::simplify::SimplifyInfo;
-use datafusion_expr::{
- expr::AggregateFunction, function::AccumulatorArgs, Accumulator,
AggregateUDF,
- AggregateUDFImpl, GroupsAccumulator, Signature,
-};
-
-/// This example shows how to use the AggregateUDFImpl::simplify API to
simplify/replace user
-/// defined aggregate function with a different expression which is defined in
the `simplify` method.
-
-#[derive(Debug, Clone)]
-struct BetterAvgUdaf {
- signature: Signature,
-}
-
-impl BetterAvgUdaf {
- /// Create a new instance of the GeoMeanUdaf struct
- fn new() -> Self {
- Self {
- signature: Signature::exact(vec![DataType::Float64],
Volatility::Immutable),
- }
- }
-}
-
-impl AggregateUDFImpl for BetterAvgUdaf {
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn name(&self) -> &str {
- "better_avg"
- }
-
- fn signature(&self) -> &Signature {
- &self.signature
- }
-
- fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
- Ok(DataType::Float64)
- }
-
- fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
- unimplemented!("should not be invoked")
- }
-
- fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
- unimplemented!("should not be invoked")
- }
-
- fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
- true
- }
-
- fn create_groups_accumulator(
- &self,
- _args: AccumulatorArgs,
- ) -> Result<Box<dyn GroupsAccumulator>> {
- unimplemented!("should not get here");
- }
-
- // we override method, to return new expression which would substitute
- // user defined function call
- fn simplify(&self) -> Option<AggregateFunctionSimplification> {
- // as an example for this functionality we replace UDF function
- // with build-in aggregate function to illustrate the use
- let simplify = |aggregate_function: AggregateFunction, _: &dyn
SimplifyInfo| {
- Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
- avg_udaf(),
- // yes it is the same Avg, `BetterAvgUdaf` was just a
- // marketing pitch :)
- aggregate_function.args,
- aggregate_function.distinct,
- aggregate_function.filter,
- aggregate_function.order_by,
- aggregate_function.null_treatment,
- )))
- };
-
- Some(Box::new(simplify))
- }
-}
-
-// create local session context with an in-memory table
-fn create_context() -> Result<SessionContext> {
- use datafusion::datasource::MemTable;
- // define a schema.
- let schema = Arc::new(Schema::new(vec![
- Field::new("a", DataType::Float32, false),
- Field::new("b", DataType::Float32, false),
- ]));
-
- // define data in two partitions
- let batch1 = RecordBatch::try_new(
- schema.clone(),
- vec![
- Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
- Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
- ],
- )?;
- let batch2 = RecordBatch::try_new(
- schema.clone(),
- vec![
- Arc::new(Float32Array::from(vec![16.0])),
- Arc::new(Float32Array::from(vec![2.0])),
- ],
- )?;
-
- let ctx = SessionContext::new();
-
- // declare a table in memory. In spark API, this corresponds to
createDataFrame(...).
- let provider = MemTable::try_new(schema, vec![vec![batch1],
vec![batch2]])?;
- ctx.register_table("t", Arc::new(provider))?;
- Ok(ctx)
-}
-
-#[tokio::main]
-async fn main() -> Result<()> {
- let ctx = create_context()?;
-
- let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
- ctx.register_udaf(better_avg.clone());
-
- let result = ctx
- .sql("SELECT better_avg(a) FROM t group by b")
- .await?
- .collect()
- .await?;
-
- let expected = [
- "+-----------------+",
- "| better_avg(t.a) |",
- "+-----------------+",
- "| 7.5 |",
- "+-----------------+",
- ];
-
- assert_batches_eq!(expected, &result);
-
- let df = ctx.table("t").await?;
- let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;
-
- let results = df.collect().await?;
- let result = as_float64_array(results[0].column(0))?;
-
- assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
- println!("The average of [2,4,8,16] is {}", result.value(0));
-
- Ok(())
-}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]