This is an automated email from the ASF dual-hosted git repository.

xudong963 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 933fec845e Consolidate example: simplify_udaf_expression.rs into advanced_udaf.rs (#13905)
933fec845e is described below

commit 933fec845e02bb8983a0b932cfa12ebb27054748
Author: Takahiro Ebato <[email protected]>
AuthorDate: Fri Dec 27 15:32:56 2024 +0900

    Consolidate example: simplify_udaf_expression.rs into advanced_udaf.rs (#13905)
---
 datafusion-examples/examples/advanced_udaf.rs      | 185 +++++++++++++++------
 .../examples/simplify_udaf_expression.rs           | 176 --------------------
 2 files changed, 132 insertions(+), 229 deletions(-)

diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs
index 414596bdc6..a914cea4a9 100644
--- a/datafusion-examples/examples/advanced_udaf.rs
+++ b/datafusion-examples/examples/advanced_udaf.rs
@@ -31,7 +31,9 @@ use datafusion::error::Result;
 use datafusion::prelude::*;
 use datafusion_common::{cast::as_float64_array, ScalarValue};
 use datafusion_expr::{
-    function::{AccumulatorArgs, StateFieldsArgs},
+    expr::AggregateFunction,
+    function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs},
+    simplify::SimplifyInfo,
     Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
 };
 
@@ -197,40 +199,6 @@ impl Accumulator for GeometricMean {
     }
 }
 
-// create local session context with an in-memory table
-fn create_context() -> Result<SessionContext> {
-    use datafusion::datasource::MemTable;
-    // define a schema.
-    let schema = Arc::new(Schema::new(vec![
-        Field::new("a", DataType::Float32, false),
-        Field::new("b", DataType::Float32, false),
-    ]));
-
-    // define data in two partitions
-    let batch1 = RecordBatch::try_new(
-        schema.clone(),
-        vec![
-            Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
-            Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
-        ],
-    )?;
-    let batch2 = RecordBatch::try_new(
-        schema.clone(),
-        vec![
-            Arc::new(Float32Array::from(vec![64.0])),
-            Arc::new(Float32Array::from(vec![2.0])),
-        ],
-    )?;
-
-    // declare a new context. In spark API, this corresponds to a new spark SQLsession
-    let ctx = SessionContext::new();
-
-    // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
-    let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
-    ctx.register_table("t", Arc::new(provider))?;
-    Ok(ctx)
-}
-
 // Define a `GroupsAccumulator` for GeometricMean
 /// which handles accumulator state for multiple groups at once.
 /// This API is significantly more complicated than `Accumulator`, which manages
@@ -399,35 +367,146 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
     }
 }
 
+/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
+/// defined aggregate function with a different expression which is defined in the `simplify` method.
+#[derive(Debug, Clone)]
+struct SimplifiedGeoMeanUdaf {
+    signature: Signature,
+}
+
+impl SimplifiedGeoMeanUdaf {
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for SimplifiedGeoMeanUdaf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "simplified_geo_mean"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        unimplemented!("should not be invoked")
+    }
+
+    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
+        unimplemented!("should not be invoked")
+    }
+
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
+        true
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        _args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        unimplemented!("should not get here");
+    }
+
+    /// Optionally replaces a UDAF with another expression during query optimization.
+    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
+        let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| {
+            // Replaces the UDAF with `GeoMeanUdaf` as a placeholder example to demonstrate the `simplify` method.
+            // In real-world scenarios, you might create UDFs from built-in expressions.
+            Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
+                Arc::new(AggregateUDF::from(GeoMeanUdaf::new())),
+                aggregate_function.args,
+                aggregate_function.distinct,
+                aggregate_function.filter,
+                aggregate_function.order_by,
+                aggregate_function.null_treatment,
+            )))
+        };
+        Some(Box::new(simplify))
+    }
+}
+
+// create local session context with an in-memory table
+fn create_context() -> Result<SessionContext> {
+    use datafusion::datasource::MemTable;
+    // define a schema.
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Float32, false),
+        Field::new("b", DataType::Float32, false),
+    ]));
+
+    // define data in two partitions
+    let batch1 = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
+            Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
+        ],
+    )?;
+    let batch2 = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Float32Array::from(vec![64.0])),
+            Arc::new(Float32Array::from(vec![2.0])),
+        ],
+    )?;
+
+    // declare a new context. In spark API, this corresponds to a new spark SQLsession
+    let ctx = SessionContext::new();
+
+    // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
+    let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
+    ctx.register_table("t", Arc::new(provider))?;
+    Ok(ctx)
+}
+
 #[tokio::main]
 async fn main() -> Result<()> {
     let ctx = create_context()?;
 
-    // create the AggregateUDF
-    let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new());
-    ctx.register_udaf(geometric_mean.clone());
+    let geo_mean_udf = AggregateUDF::from(GeoMeanUdaf::new());
+    let simplified_geo_mean_udf = AggregateUDF::from(SimplifiedGeoMeanUdaf::new());
+
+    for (udf, udf_name) in [
+        (geo_mean_udf, "geo_mean"),
+        (simplified_geo_mean_udf, "simplified_geo_mean"),
+    ] {
+        ctx.register_udaf(udf.clone());
 
-    let sql_df = ctx.sql("SELECT geo_mean(a) FROM t group by b").await?;
-    sql_df.show().await?;
+        let sql_df = ctx
+            .sql(&format!("SELECT {}(a) FROM t GROUP BY b", udf_name))
+            .await?;
+        sql_df.show().await?;
 
-    // get a DataFrame from the context
-    // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
-    let df = ctx.table("t").await?;
+        // get a DataFrame from the context
+        // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
+        let df = ctx.table("t").await?;
 
-    // perform the aggregation
-    let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?;
+        // perform the aggregation
+        let df = df.aggregate(vec![], vec![udf.call(vec![col("a")])])?;
 
-    // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.
+        // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.
 
-    // execute the query
-    let results = df.collect().await?;
+        // execute the query
+        let results = df.collect().await?;
 
-    // downcast the array to the expected type
-    let result = as_float64_array(results[0].column(0))?;
+        // downcast the array to the expected type
+        let result = as_float64_array(results[0].column(0))?;
 
-    // verify that the calculation is correct
-    assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
-    println!("The geometric mean of [2,4,8,64] is {}", result.value(0));
+        // verify that the calculation is correct
+        assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
+        println!("The geometric mean of [2,4,8,64] is {}", result.value(0));
+    }
 
     Ok(())
 }
diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs
deleted file mode 100644
index 52a27317e3..0000000000
--- a/datafusion-examples/examples/simplify_udaf_expression.rs
+++ /dev/null
@@ -1,176 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use std::{any::Any, sync::Arc};
-
-use arrow_schema::{Field, Schema};
-
-use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
-use datafusion::error::Result;
-use datafusion::functions_aggregate::average::avg_udaf;
-use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
-use datafusion::{assert_batches_eq, prelude::*};
-use datafusion_common::cast::as_float64_array;
-use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
-use datafusion_expr::simplify::SimplifyInfo;
-use datafusion_expr::{
-    expr::AggregateFunction, function::AccumulatorArgs, Accumulator, AggregateUDF,
-    AggregateUDFImpl, GroupsAccumulator, Signature,
-};
-
-/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
-/// defined aggregate function with a different expression which is defined in the `simplify` method.
-
-#[derive(Debug, Clone)]
-struct BetterAvgUdaf {
-    signature: Signature,
-}
-
-impl BetterAvgUdaf {
-    /// Create a new instance of the GeoMeanUdaf struct
-    fn new() -> Self {
-        Self {
-            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
-        }
-    }
-}
-
-impl AggregateUDFImpl for BetterAvgUdaf {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn name(&self) -> &str {
-        "better_avg"
-    }
-
-    fn signature(&self) -> &Signature {
-        &self.signature
-    }
-
-    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
-        Ok(DataType::Float64)
-    }
-
-    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
-        unimplemented!("should not be invoked")
-    }
-
-    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
-        unimplemented!("should not be invoked")
-    }
-
-    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
-        true
-    }
-
-    fn create_groups_accumulator(
-        &self,
-        _args: AccumulatorArgs,
-    ) -> Result<Box<dyn GroupsAccumulator>> {
-        unimplemented!("should not get here");
-    }
-
-    // we override method, to return new expression which would substitute
-    // user defined function call
-    fn simplify(&self) -> Option<AggregateFunctionSimplification> {
-        // as an example for this functionality we replace UDF function
-        // with build-in aggregate function to illustrate the use
-        let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| {
-            Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
-                avg_udaf(),
-                // yes it is the same Avg, `BetterAvgUdaf` was just a
-                // marketing pitch :)
-                aggregate_function.args,
-                aggregate_function.distinct,
-                aggregate_function.filter,
-                aggregate_function.order_by,
-                aggregate_function.null_treatment,
-            )))
-        };
-
-        Some(Box::new(simplify))
-    }
-}
-
-// create local session context with an in-memory table
-fn create_context() -> Result<SessionContext> {
-    use datafusion::datasource::MemTable;
-    // define a schema.
-    let schema = Arc::new(Schema::new(vec![
-        Field::new("a", DataType::Float32, false),
-        Field::new("b", DataType::Float32, false),
-    ]));
-
-    // define data in two partitions
-    let batch1 = RecordBatch::try_new(
-        schema.clone(),
-        vec![
-            Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
-            Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
-        ],
-    )?;
-    let batch2 = RecordBatch::try_new(
-        schema.clone(),
-        vec![
-            Arc::new(Float32Array::from(vec![16.0])),
-            Arc::new(Float32Array::from(vec![2.0])),
-        ],
-    )?;
-
-    let ctx = SessionContext::new();
-
-    // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
-    let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
-    ctx.register_table("t", Arc::new(provider))?;
-    Ok(ctx)
-}
-
-#[tokio::main]
-async fn main() -> Result<()> {
-    let ctx = create_context()?;
-
-    let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
-    ctx.register_udaf(better_avg.clone());
-
-    let result = ctx
-        .sql("SELECT better_avg(a) FROM t group by b")
-        .await?
-        .collect()
-        .await?;
-
-    let expected = [
-        "+-----------------+",
-        "| better_avg(t.a) |",
-        "+-----------------+",
-        "| 7.5             |",
-        "+-----------------+",
-    ];
-
-    assert_batches_eq!(expected, &result);
-
-    let df = ctx.table("t").await?;
-    let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;
-
-    let results = df.collect().await?;
-    let result = as_float64_array(results[0].column(0))?;
-
-    assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
-    println!("The average of [2,4,8,16] is {}", result.value(0));
-
-    Ok(())
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to