alamb commented on a change in pull request #9376:
URL: https://github.com/apache/arrow/pull/9376#discussion_r568535264



##########
File path: rust/datafusion/examples/simple_udf.rs
##########
@@ -58,77 +54,54 @@ fn create_context() -> Result<ExecutionContext> {
     Ok(ctx)
 }
 
-// a small utility function to compute pow(base, exponent)
-fn maybe_pow(base: &Option<f64>, exponent: &Option<f64>) -> Option<f64> {
-    match (base, exponent) {
-        // in arrow, any value can be null.
-        // Here we decide to make our UDF to return null when either base or 
exponent is null.
-        (Some(base), Some(exponent)) => Some(base.powf(*exponent)),
-        _ => None,
-    }
-}
-
-fn pow_array(base: &dyn Array, exponent: &dyn Array) -> Result<ArrayRef> {
-    // 1. cast both arguments to f64. These casts MUST be aligned with the 
signature or this function panics!
-    let base = base
-        .as_any()
-        .downcast_ref::<Float64Array>()
-        .expect("cast failed");
-    let exponent = exponent
-        .as_any()
-        .downcast_ref::<Float64Array>()
-        .expect("cast failed");
-
-    // this is guaranteed by DataFusion. We place it just to make it obvious.
-    assert_eq!(exponent.len(), base.len());
-
-    // 2. perform the computation
-    let array = base
-        .iter()
-        .zip(exponent.iter())
-        .map(|(base, exponent)| maybe_pow(&base, &exponent))
-        .collect::<Float64Array>();
-
-    // `Ok` because no error occurred during the calculation (we should add 
one if exponent was [0, 1[ and the base < 0 because that panics!)
-    // `Arc` because arrays are immutable, thread-safe, trait objects.
-    Ok(Arc::new(array))
-}
-
 /// In this example we will declare a single-type, single return type UDF that 
exponentiates f64, a^b
 #[tokio::main]
 async fn main() -> Result<()> {
     let mut ctx = create_context()?;
 
     // First, declare the actual implementation of the calculation
-    let pow: ScalarFunctionImplementation = Arc::new(|args: &[ColumnarValue]| {
-        // in DataFusion, all `args` and output are `ColumnarValue`, an enum 
of either a scalar or a dynamically-typed array.
-        // we can cater for both, or document that the UDF only supports some 
variants.
-        // here we will assume that al
+    let pow = |args: &[ArrayRef]| {
+        // in DataFusion, all `args` and output are dynamically-typed arrays, 
which means that we need to:
         // 1. cast the values to the type we want
         // 2. perform the computation for every element in the array (using a 
loop or SIMD) and construct the result
 
         // this is guaranteed by DataFusion based on the function's signature.
         assert_eq!(args.len(), 2);
 
-        let (base, exponent) = (&args[0], &args[1]);
-
-        let result = match (base, exponent) {
-            (
-                ColumnarValue::Scalar(ScalarValue::Float64(base)),
-                ColumnarValue::Scalar(ScalarValue::Float64(exponent)),
-            ) => ColumnarValue::Scalar(ScalarValue::Float64(maybe_pow(base, 
exponent))),
-            (ColumnarValue::Array(base), ColumnarValue::Array(exponent)) => {
-                let array = pow_array(base.as_ref(), exponent.as_ref())?;
-                ColumnarValue::Array(array)
-            }
-            _ => {
-                return Err(DataFusionError::Execution(
-                    "This UDF only supports f64".to_string(),
-                ))
-            }
-        };
-        Ok(result)
-    });
+        // 1. cast both arguments to f64. These casts MUST be aligned with the 
signature or this function panics!
+        let base = &args[0]
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .expect("cast failed");
+        let exponent = &args[1]
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .expect("cast failed");
+
+        // this is guaranteed by DataFusion. We place it just to make it 
obvious.
+        assert_eq!(exponent.len(), base.len());
+
+        // 2. perform the computation
+        let array = base
+            .iter()
+            .zip(exponent.iter())
+            .map(|(base, exponent)| {
+                match (base, exponent) {
+                    // in arrow, any value can be null.
+                    // Here we decide to make our UDF to return null when 
either base or exponent is null.
+                    (Some(base), Some(exponent)) => Some(base.powf(exponent)),
+                    _ => None,
+                }
+            })
+            .collect::<Float64Array>();
+
+        // `Ok` because no error occurred during the calculation (we should 
add one if exponent was [0, 1[ and the base < 0 because that panics!)
+        // `Arc` because arrays are immutable, thread-safe, trait objects.
+        Ok(Arc::new(array) as ArrayRef)
+    };
+    // the function above expects an `ArrayRef`, but DataFusion may pass a 
scalar to a UDF.
+    // thus, we use `make_scalar_function` to decorare the closure so that it 
can handle both Arrays and Scalar values.

Review comment:
       👍 

##########
File path: rust/datafusion/examples/simple_udf.rs
##########
@@ -54,50 +58,76 @@ fn create_context() -> Result<ExecutionContext> {
     Ok(ctx)
 }
 
+// a small utility function to compute pow(base, exponent)
+fn maybe_pow(base: &Option<f64>, exponent: &Option<f64>) -> Option<f64> {
+    match (base, exponent) {
+        // in arrow, any value can be null.
+        // Here we decide to make our UDF to return null when either base or 
exponent is null.
+        (Some(base), Some(exponent)) => Some(base.powf(*exponent)),
+        _ => None,
+    }
+}
+
+fn pow_array(base: &dyn Array, exponent: &dyn Array) -> Result<ArrayRef> {
+    // 1. cast both arguments to f64. These casts MUST be aligned with the 
signature or this function panics!
+    let base = base
+        .as_any()
+        .downcast_ref::<Float64Array>()
+        .expect("cast failed");
+    let exponent = exponent
+        .as_any()
+        .downcast_ref::<Float64Array>()
+        .expect("cast failed");
+
+    // this is guaranteed by DataFusion. We place it just to make it obvious.
+    assert_eq!(exponent.len(), base.len());
+
+    // 2. perform the computation
+    let array = base
+        .iter()
+        .zip(exponent.iter())
+        .map(|(base, exponent)| maybe_pow(&base, &exponent))
+        .collect::<Float64Array>();
+
+    // `Ok` because no error occurred during the calculation (we should add 
one if exponent was [0, 1[ and the base < 0 because that panics!)

Review comment:
       cool -- sorry, I guess I am used to seeing a half-open interval written 
with a `)` -- so in this case something like `[0, 1)` to represent 
`0 <= exponent < 1` (e.g. 
[here](https://en.wikipedia.org/wiki/Interval_(mathematics)#Terminology)).




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to