alamb commented on a change in pull request #9376:
URL: https://github.com/apache/arrow/pull/9376#discussion_r568532324



##########
File path: rust/datafusion/src/physical_plan/datetime_expressions.rs
##########
@@ -167,152 +175,166 @@ fn naive_datetime_to_timestamp(s: &str, datetime: 
NaiveDateTime) -> Result<i64>
     }
 }
 
-/// convert an array of strings into `Timestamp(Nanosecond, None)`
-pub fn to_timestamp(args: &[ArrayRef]) -> Result<TimestampNanosecondArray> {
-    let num_rows = args[0].len();
-    let string_args =
-        &args[0]
-            .as_any()
-            .downcast_ref::<StringArray>()
-            .ok_or_else(|| {
-                DataFusionError::Internal(
-                    "could not cast to_timestamp input to 
StringArray".to_string(),
-                )
-            })?;
-
-    let result = (0..num_rows)
-        .map(|i| {
-            if string_args.is_null(i) {
-                // NB: Since we use the same null bitset as the input,
-                // the output for this value will be ignored, but we
-                // need some value in the array we are building.
-                Ok(0)
-            } else {
-                string_to_timestamp_nanos(string_args.value(i))
+pub(crate) fn unary_string_to_primitive_function<'a, T, O, F>(
+    args: &[&'a dyn Array],
+    op: F,
+    name: &str,
+) -> Result<PrimitiveArray<O>>
+where
+    O: ArrowPrimitiveType,
+    T: StringOffsetSizeTrait,
+    F: Fn(&'a str) -> Result<O::Native>,
+{
+    if args.len() != 1 {
+        return Err(DataFusionError::Internal(format!(
+            "{:?} args were supplied but {} takes exactly one argument",
+            args.len(),
+            name,
+        )));
+    }
+
+    let array = args[0]
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+
+    // first map is the iterator, second is for the `Option<_>`
+    array.iter().map(|x| x.map(|x| op(x)).transpose()).collect()
+}
+
+fn handle<'a, O, F, S>(

Review comment:
       makes sense -- thank you for the explanation

##########
File path: rust/datafusion/src/physical_plan/datetime_expressions.rs
##########
@@ -167,152 +175,166 @@ fn naive_datetime_to_timestamp(s: &str, datetime: 
NaiveDateTime) -> Result<i64>
     }
 }
 
-/// convert an array of strings into `Timestamp(Nanosecond, None)`
-pub fn to_timestamp(args: &[ArrayRef]) -> Result<TimestampNanosecondArray> {
-    let num_rows = args[0].len();
-    let string_args =
-        &args[0]
-            .as_any()
-            .downcast_ref::<StringArray>()
-            .ok_or_else(|| {
-                DataFusionError::Internal(
-                    "could not cast to_timestamp input to 
StringArray".to_string(),
-                )
-            })?;
-
-    let result = (0..num_rows)
-        .map(|i| {
-            if string_args.is_null(i) {
-                // NB: Since we use the same null bitset as the input,
-                // the output for this value will be ignored, but we
-                // need some value in the array we are building.
-                Ok(0)
-            } else {
-                string_to_timestamp_nanos(string_args.value(i))
+pub(crate) fn unary_string_to_primitive_function<'a, T, O, F>(
+    args: &[&'a dyn Array],
+    op: F,
+    name: &str,
+) -> Result<PrimitiveArray<O>>
+where
+    O: ArrowPrimitiveType,
+    T: StringOffsetSizeTrait,
+    F: Fn(&'a str) -> Result<O::Native>,
+{
+    if args.len() != 1 {
+        return Err(DataFusionError::Internal(format!(
+            "{:?} args were supplied but {} takes exactly one argument",
+            args.len(),
+            name,
+        )));
+    }
+
+    let array = args[0]
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+
+    // first map is the iterator, second is for the `Option<_>`
+    array.iter().map(|x| x.map(|x| op(x)).transpose()).collect()
+}
+
+fn handle<'a, O, F, S>(

Review comment:
       makes sense -- thank you for the explanation

##########
File path: rust/datafusion/examples/simple_udf.rs
##########
@@ -58,77 +54,54 @@ fn create_context() -> Result<ExecutionContext> {
     Ok(ctx)
 }
 
-// a small utility function to compute pow(base, exponent)
-fn maybe_pow(base: &Option<f64>, exponent: &Option<f64>) -> Option<f64> {
-    match (base, exponent) {
-        // in arrow, any value can be null.
-        // Here we decide to make our UDF to return null when either base or 
exponent is null.
-        (Some(base), Some(exponent)) => Some(base.powf(*exponent)),
-        _ => None,
-    }
-}
-
-fn pow_array(base: &dyn Array, exponent: &dyn Array) -> Result<ArrayRef> {
-    // 1. cast both arguments to f64. These casts MUST be aligned with the 
signature or this function panics!
-    let base = base
-        .as_any()
-        .downcast_ref::<Float64Array>()
-        .expect("cast failed");
-    let exponent = exponent
-        .as_any()
-        .downcast_ref::<Float64Array>()
-        .expect("cast failed");
-
-    // this is guaranteed by DataFusion. We place it just to make it obvious.
-    assert_eq!(exponent.len(), base.len());
-
-    // 2. perform the computation
-    let array = base
-        .iter()
-        .zip(exponent.iter())
-        .map(|(base, exponent)| maybe_pow(&base, &exponent))
-        .collect::<Float64Array>();
-
-    // `Ok` because no error occurred during the calculation (we should add 
one if exponent was [0, 1[ and the base < 0 because that panics!)
-    // `Arc` because arrays are immutable, thread-safe, trait objects.
-    Ok(Arc::new(array))
-}
-
 /// In this example we will declare a single-type, single return type UDF that 
exponentiates f64, a^b
 #[tokio::main]
 async fn main() -> Result<()> {
     let mut ctx = create_context()?;
 
     // First, declare the actual implementation of the calculation
-    let pow: ScalarFunctionImplementation = Arc::new(|args: &[ColumnarValue]| {
-        // in DataFusion, all `args` and output are `ColumnarValue`, an enum 
of either a scalar or a dynamically-typed array.
-        // we can cater for both, or document that the UDF only supports some 
variants.
-        // here we will assume that al
+    let pow = |args: &[ArrayRef]| {
+        // in DataFusion, all `args` and output are dynamically-typed arrays, 
which means that we need to:
         // 1. cast the values to the type we want
         // 2. perform the computation for every element in the array (using a 
loop or SIMD) and construct the result
 
         // this is guaranteed by DataFusion based on the function's signature.
         assert_eq!(args.len(), 2);
 
-        let (base, exponent) = (&args[0], &args[1]);
-
-        let result = match (base, exponent) {
-            (
-                ColumnarValue::Scalar(ScalarValue::Float64(base)),
-                ColumnarValue::Scalar(ScalarValue::Float64(exponent)),
-            ) => ColumnarValue::Scalar(ScalarValue::Float64(maybe_pow(base, 
exponent))),
-            (ColumnarValue::Array(base), ColumnarValue::Array(exponent)) => {
-                let array = pow_array(base.as_ref(), exponent.as_ref())?;
-                ColumnarValue::Array(array)
-            }
-            _ => {
-                return Err(DataFusionError::Execution(
-                    "This UDF only supports f64".to_string(),
-                ))
-            }
-        };
-        Ok(result)
-    });
+        // 1. cast both arguments to f64. These casts MUST be aligned with the 
signature or this function panics!
+        let base = &args[0]
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .expect("cast failed");
+        let exponent = &args[1]
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .expect("cast failed");
+
+        // this is guaranteed by DataFusion. We place it just to make it 
obvious.
+        assert_eq!(exponent.len(), base.len());
+
+        // 2. perform the computation
+        let array = base
+            .iter()
+            .zip(exponent.iter())
+            .map(|(base, exponent)| {
+                match (base, exponent) {
+                    // in arrow, any value can be null.
+                    // Here we decide to make our UDF to return null when 
either base or exponent is null.
+                    (Some(base), Some(exponent)) => Some(base.powf(exponent)),
+                    _ => None,
+                }
+            })
+            .collect::<Float64Array>();
+
+        // `Ok` because no error occurred during the calculation (we should 
add one if exponent was [0, 1[ and the base < 0 because that panics!)
+        // `Arc` because arrays are immutable, thread-safe, trait objects.
+        Ok(Arc::new(array) as ArrayRef)
+    };
+    // the function above expects an `ArrayRef`, but DataFusion may pass a 
scalar to a UDF.
+    // thus, we use `make_scalar_function` to decorate the closure so that it 
can handle both Arrays and Scalar values.

Review comment:
       👍 

##########
File path: rust/datafusion/examples/simple_udf.rs
##########
@@ -54,50 +58,76 @@ fn create_context() -> Result<ExecutionContext> {
     Ok(ctx)
 }
 
+// a small utility function to compute pow(base, exponent)
+fn maybe_pow(base: &Option<f64>, exponent: &Option<f64>) -> Option<f64> {
+    match (base, exponent) {
+        // in arrow, any value can be null.
+        // Here we decide to make our UDF to return null when either base or 
exponent is null.
+        (Some(base), Some(exponent)) => Some(base.powf(*exponent)),
+        _ => None,
+    }
+}
+
+fn pow_array(base: &dyn Array, exponent: &dyn Array) -> Result<ArrayRef> {
+    // 1. cast both arguments to f64. These casts MUST be aligned with the 
signature or this function panics!
+    let base = base
+        .as_any()
+        .downcast_ref::<Float64Array>()
+        .expect("cast failed");
+    let exponent = exponent
+        .as_any()
+        .downcast_ref::<Float64Array>()
+        .expect("cast failed");
+
+    // this is guaranteed by DataFusion. We place it just to make it obvious.
+    assert_eq!(exponent.len(), base.len());
+
+    // 2. perform the computation
+    let array = base
+        .iter()
+        .zip(exponent.iter())
+        .map(|(base, exponent)| maybe_pow(&base, &exponent))
+        .collect::<Float64Array>();
+
+    // `Ok` because no error occurred during the calculation (we should add 
one if exponent was [0, 1[ and the base < 0 because that panics!)

Review comment:
       cool -- sorry, I guess I am used to seeing the open end of an interval written with a `)` 
-- so in this case something like `[0, 1)` to represent `0 <= exponent < 1` 
(e.g. [here](https://en.wikipedia.org/wiki/Interval_(mathematics)#Terminology))




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to