alamb commented on code in PR #8171:
URL: https://github.com/apache/arrow-datafusion/pull/8171#discussion_r1394621628
##########
docs/source/library-user-guide/adding-udfs.md:
##########
@@ -115,10 +122,313 @@ let df = ctx.sql(&sql).await.unwrap();
Scalar UDFs are functions that take a row of data and return a single value.
Window UDFs are similar, but they also have access to the rows around them.
Access to the proximal rows is helpful, but adds some complexity to the
implementation.
-Body coming soon.
+For example, we will declare a user defined window function that computes a
moving average.
+
+```rust
+use datafusion::arrow::{array::{ArrayRef, Float64Array, AsArray},
datatypes::Float64Type};
+use datafusion::logical_expr::{PartitionEvaluator};
+use datafusion::common::ScalarValue;
+use datafusion::error::Result;
+/// This implements the lowest level evaluation for a window function.
+///
+/// It handles calculating the value of the window function for each
+/// distinct value of the `PARTITION BY` clause.
+#[derive(Clone, Debug)]
+struct MyPartitionEvaluator {}
+
+impl MyPartitionEvaluator {
+    /// Creates a new evaluator. It has no fields, so each output value
+    /// is computed purely from the rows of its window frame.
+    fn new() -> Self {
+        Self {}
+    }
+}
+
+/// Different evaluation methods are called depending on the various
+/// settings of `WindowUDF`. This example uses the simplest and most
+/// general, `evaluate`. See `PartitionEvaluator` for the other, more
+/// advanced uses.
+impl PartitionEvaluator for MyPartitionEvaluator {
+    /// Tell DataFusion the window function varies based on the value
+    /// of the window frame.
+    fn uses_window_frame(&self) -> bool {
+        true
+    }
+
+    /// This function is called once per input row.
+    ///
+    /// `range` specifies which indexes of `values` should be
+    /// considered for the calculation.
+    ///
+    /// Note this is the SLOWEST, but simplest, way to evaluate a
+    /// window function. It is much faster to implement
+    /// `evaluate_all` or `evaluate_all_with_rank`, if possible.
+    fn evaluate(
+        &mut self,
+        values: &[ArrayRef],
+        range: &std::ops::Range<usize>,
+    ) -> Result<ScalarValue> {
+        // The input argument is an array of floating point numbers
+        // over which we calculate a moving average.
+        let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();
+
+        // Restrict the input to the rows of the current window frame.
+        let window = arr.slice(range.start, range.end - range.start);
+
+        // Our smoothing function averages all the non-null values in the
+        // window frame. NULL entries are skipped so they do not contribute
+        // to the average.
+        let (sum, count) = window
+            .iter()
+            .flatten()
+            .fold((0.0_f64, 0_usize), |(sum, count), v| (sum + v, count + 1));
+
+        // An empty frame, or one containing only NULLs, produces NULL.
+        let output = if count > 0 {
+            Some(sum / count as f64)
+        } else {
+            None
+        };
+
+        Ok(ScalarValue::Float64(output))
+    }
+}
+
+/// Create a `PartitionEvaluator` to evaluate this function on a new
+/// partition.
+fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
+    Ok(Box::new(MyPartitionEvaluator::new()))
+}
+```
+
+### Registering a Window UDF
+
+To register a Window UDF, you need to wrap the function implementation in a `WindowUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udwf` helper function to make this easier.
+
+```rust
+use datafusion::logical_expr::{Volatility, create_udwf};
+use datafusion::arrow::datatypes::DataType;
+use std::sync::Arc;
+
+// Define the UDWF and declare its signature: a single `Float64` input
+// argument and a `Float64` return value.
+let input_type = DataType::Float64;
+let return_type = Arc::new(DataType::Float64);
+let smooth_it = create_udwf(
+    "smooth_it",
+    input_type,
+    return_type,
+    Volatility::Immutable,
+    Arc::new(make_partition_evaluator),
+);
+```
+
+`create_udwf` takes five arguments:
+
+- The first argument is the name of the function. This is the name that will
be used in SQL queries.
+- **The second argument** is the `DataType` of the input argument (note: this is a single `DataType`, not a list of types). In this case, the function accepts a `Float64` argument.
+- The third argument is the return type of the function. In this case, the function returns a `Float64`.
+- The fourth argument is the volatility of the function. In short, this is
used to determine if the function's performance can be optimized in some
situations. In this case, the function is `Immutable` because it always returns
the same value for the same input. A random number generator would be
`Volatile` because it returns a different value for the same input.
+- **The fifth argument** is the function implementation: a factory that creates a `PartitionEvaluator` for each partition. Here, this is the `make_partition_evaluator` function we defined above.
+
+That gives us a `WindowUDF` that we can register with the `SessionContext`:
+
+```rust
+use datafusion::execution::context::SessionContext;
+
+// Create a new session to run queries in.
+let ctx = SessionContext::new();
+
+// Register `smooth_it` so that SQL queries in this session can call it
+// by the name given to `create_udwf`.
+ctx.register_udwf(smooth_it);
+```
+
+At this point, you can use the `smooth_it` function in your query:
+
+For example, suppose we have a [`cars.csv`](https://github.com/apache/arrow-datafusion/blob/main/datafusion/core/tests/data/cars.csv) whose contents look like
+
+```csv
+car,speed,time
+red,20.0,1996-04-12T12:05:03.000000000
+red,20.3,1996-04-12T12:05:04.000000000
+green,10.0,1996-04-12T12:05:03.000000000
+green,10.3,1996-04-12T12:05:04.000000000
+...
+```
+
+Then we can query it like this:
+
+```rust
+use datafusion::datasource::file_format::options::CsvReadOptions;
+// register csv table first
+let csv_path = "cars.csv";
+ctx.register_csv("cars", csv_path, CsvReadOptions::default().has_header(true)).await?;
+// do query with smooth_it; ordering by `time` as well as `car` makes the
+// row order within each car deterministic, as shown in the output below
+let df = ctx
+    .sql(
+        "SELECT \
+            car, \
+            speed, \
+            smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed, \
+            time \
+        FROM cars \
+        ORDER BY \
+            car, time",
+    )
+    .await?;
+// print the results
+df.show().await?;
+```
+
+The output will look like:
+
+```text
++-------+-------+--------------------+---------------------+
+| car | speed | smooth_speed | time |
++-------+-------+--------------------+---------------------+
+| green | 10.0 | 10.0 | 1996-04-12T12:05:03 |
+| green | 10.3 | 10.15 | 1996-04-12T12:05:04 |
+| green | 10.4 | 10.233333333333334 | 1996-04-12T12:05:05 |
+| green | 10.5 | 10.3 | 1996-04-12T12:05:06 |
+| green | 11.0 | 10.440000000000001 | 1996-04-12T12:05:07 |
+| green | 12.0 | 10.700000000000001 | 1996-04-12T12:05:08 |
+| green | 14.0 | 11.171428571428573 | 1996-04-12T12:05:09 |
+| green | 15.0 | 11.65 | 1996-04-12T12:05:10 |
+| green | 15.1 | 12.033333333333333 | 1996-04-12T12:05:11 |
+| green | 15.2 | 12.35 | 1996-04-12T12:05:12 |
+| green | 8.0 | 11.954545454545455 | 1996-04-12T12:05:13 |
+| green | 2.0 | 11.125 | 1996-04-12T12:05:14 |
+| red | 20.0 | 20.0 | 1996-04-12T12:05:03 |
+| red | 20.3 | 20.15 | 1996-04-12T12:05:04 |
+...
+```
## Adding an Aggregate UDF
Aggregate UDFs are functions that take a group of rows and return a single
value. These are akin to SQL's `SUM` or `COUNT` functions.
-Body coming soon.
+For example, we will declare a UDAF with a single input type and a single return type that computes the geometric mean.
+
+```rust
+use datafusion::arrow::array::ArrayRef;
+use datafusion::scalar::ScalarValue;
+use datafusion::{error::Result, physical_plan::Accumulator};
+
+/// A UDAF keeps state across multiple rows, so it needs a `struct` to hold
+/// that state.
+#[derive(Debug)]
+struct GeometricMean {
+    /// Number of values folded into `prod` so far.
+    n: u32,
+    /// Running product of the values seen so far.
+    prod: f64,
+}
+
+impl GeometricMean {
+    /// Creates an empty accumulator. The product starts at `1.0`, the
+    /// identity for multiplication, so the first value is taken as-is.
+    pub fn new() -> Self {
+        Self { prod: 1.0, n: 0 }
+    }
+}
+
+// UDAFs are built using the trait `Accumulator`, that offers DataFusion the
necessary functions
+// to use them.
+impl Accumulator for GeometricMean {
+    // This function serializes our state into `ScalarValue`s, which
+    // DataFusion uses to pass this state between execution stages.
+    // Note that this can be arbitrary data.
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        let product = ScalarValue::from(self.prod);
+        let count = ScalarValue::from(self.n);
+        Ok(vec![product, count])
+    }
+
+    // DataFusion expects this function to return the final value of this
+    // aggregator. In this case, this is the formula of the geometric mean:
+    // the n-th root of the product of the n input values.
+    fn evaluate(&self) -> Result<ScalarValue> {
+        // With no input values the geometric mean is undefined. Return NULL,
+        // matching SQL's convention for aggregates over no rows. Without this
+        // check the result would be `1.0_f64.powf(1.0 / 0.0)`, i.e. `1.0`.
+        if self.n == 0 {
+            return Ok(ScalarValue::Float64(None));
+        }
+        let value = self.prod.powf(1.0 / self.n as f64);
+        Ok(ScalarValue::from(value))
+    }
+
+    // DataFusion calls this function to update the accumulator's state for a
+    // batch of input rows. In this case the product is updated with the
+    // non-null values from the first column, and the count is incremented
+    // once per non-null value.
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+        let arr = &values[0];
+        (0..arr.len()).try_for_each(|index| {
+            match ScalarValue::try_from_array(arr, index)? {
+                ScalarValue::Float64(Some(value)) => {
+                    self.prod *= value;
+                    self.n += 1;
+                }
+                // NULL inputs are skipped rather than treated as an error.
+                ScalarValue::Float64(None) => {}
+                // Assumes the UDAF is registered with a `Float64` input
+                // signature, so no other type can reach this point.
+                other => unreachable!("GeometricMean expects Float64 input, got {other:?}"),
+            }
+            Ok(())
+        })
+    }
+
+ // Optimization hint: this trait also supports `update_batch` and
`merge_batch`,
Review Comment:
this comment seems out of date as it is on `merge_batch`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]