stuartcarnie commented on code in PR #6617:
URL: https://github.com/apache/arrow-datafusion/pull/6617#discussion_r1228969301


##########
datafusion-examples/examples/simple_udwf.rs:
##########
@@ -0,0 +1,210 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::{
+    array::{AsArray, Float64Array},
+    datatypes::Float64Type,
+};
+use arrow_schema::DataType;
+use datafusion::datasource::file_format::options::CsvReadOptions;
+
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use datafusion_common::DataFusionError;
+use datafusion_expr::{
+    partition_evaluator::PartitionEvaluator, Signature, Volatility, WindowUDF,
+};
+
+// create local execution context with `cars.csv` registered as a table named 
`cars`
+async fn create_context() -> Result<SessionContext> {
+    // declare a new context. In spark API, this corresponds to a new spark 
SQLsession
+    let ctx = SessionContext::new();
+
+    // declare a table in memory. In spark API, this corresponds to 
createDataFrame(...).
+    println!("pwd: {}", std::env::current_dir().unwrap().display());
+    let csv_path = format!("datafusion/core/tests/data/cars.csv");
+    let read_options = CsvReadOptions::default().has_header(true);
+
+    ctx.register_csv("cars", &csv_path, read_options).await?;
+    Ok(ctx)
+}
+
+/// In this example we will declare a user defined window function that 
computes a moving average and then run it using SQL
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context().await?;
+
+    // register the window function with DataFusion so wecan call it
+    ctx.register_udwf(my_average());
+
+    // Use SQL to run the new window function
+    let df = ctx.sql("SELECT * from cars").await?;
+    // print the results
+    df.show().await?;
+
+    // Use SQL to run the new window function
+    // `PARTITION BY car`:each distinct value of car (red, and green) should 
be treated separately
+    // `ORDER BY time`: within each group (greed or green) the values will be 
orderd by time
+    let df = ctx
+        .sql(
+            "SELECT car, \
+                      speed, \
+                      lag(speed, 1) OVER (PARTITION BY car ORDER BY time),\
+                      my_average(speed) OVER (PARTITION BY car ORDER BY time),\
+                      time \
+                      from cars",
+        )
+        .await?;
+    // print the results
+    df.show().await?;
+
+    // // ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING: Run the window functon so 
that each invocation only sees 5 rows: the 2 before and 2 after) using
+    // let df = ctx.sql("SELECT car, \
+    //                   speed, \
+    //                   lag(speed, 1) OVER (PARTITION BY car ORDER BY time 
ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING),\
+    //                   time \
+    //                   from cars").await?;
+    // // print the results
+    // df.show().await?;
+
+    // todo show how to run dataframe API as well
+
+    Ok(())
+}
+
+// TODO: make a helper function like `create_udf` that helps to make these
+// signatures
+
+/// Create the `my_average` user defined window function.
+fn my_average() -> WindowUDF {
+    WindowUDF {
+        name: String::from("my_average"),
+        // It takes exactly 1 argument: the column to average.
+        // NOTE(review): the signature accepts only `Int32`, but the file imports
+        // `Float64Array`/`Float64Type`, which suggests the evaluator works on
+        // `Float64` values — confirm which type the evaluator expects.
+        signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
+        return_type: Arc::new(return_type),
+        partition_evaluator: Arc::new(make_partition_evaluator),
+    }
+}
+
+/// Compute the return type of `my_average` given its argument types.
+///
+/// The function takes exactly one argument and returns a value of the same
+/// type as that argument.
+///
+/// # Errors
+///
+/// Returns [`DataFusionError::Plan`] if the number of arguments is not 1.
+fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> {
+    match arg_types {
+        [arg_type] => Ok(Arc::new(arg_type.clone())),
+        // Name the function as it is registered (`my_average`) so the planning
+        // error points users at the right UDWF.
+        _ => Err(DataFusionError::Plan(format!(
+            "my_average expects 1 argument, got {}: {:?}",
+            arg_types.len(),
+            arg_types
+        ))),
+    }
+}
+
+/// Build a fresh [`PartitionEvaluator`] (a new `MyPartitionEvaluator`) for the
+/// `my_average` window UDF.
+fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
+    let evaluator = MyPartitionEvaluator::new();
+    Ok(Box::new(evaluator))
+}

Review Comment:
   What do you think of passing the arguments to the factory as `PhysicalExpr` trait objects? For example:
   
   ```rust
   /// Factory that creates a PartitionEvaluator for the given window function.
   ///
   /// This function is passed its input arguments so that cases such as
   /// constants can be correctly handled. It is also passed the input schema,
   /// which is needed to evaluate the types of those arguments.
   pub type PartitionEvaluatorFunctionFactory = Arc<
       dyn Fn(&[Arc<dyn PhysicalExpr>], &Schema) -> Result<Box<dyn PartitionEvaluator>>
           + Send
           + Sync,
   >;
   ```
   
   > **Note**
   >
   > I've also included the `input_schema`, as this would be necessary to
   > evaluate the types of the arguments.
   
   This would be similar to `create_built_in_window_expr`:
   
   
https://github.com/apache/arrow-datafusion/blob/a42cc8d98b6e875c485e7e9b106d30803a32b00a/datafusion/core/src/physical_plan/windows/mod.rs#L120-L125
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to