alamb commented on a change in pull request #9376:
URL: https://github.com/apache/arrow/pull/9376#discussion_r568532324
##########
File path: rust/datafusion/src/physical_plan/datetime_expressions.rs
##########
@@ -167,152 +175,166 @@ fn naive_datetime_to_timestamp(s: &str, datetime:
NaiveDateTime) -> Result<i64>
}
}
-/// convert an array of strings into `Timestamp(Nanosecond, None)`
-pub fn to_timestamp(args: &[ArrayRef]) -> Result<TimestampNanosecondArray> {
- let num_rows = args[0].len();
- let string_args =
- &args[0]
- .as_any()
- .downcast_ref::<StringArray>()
- .ok_or_else(|| {
- DataFusionError::Internal(
- "could not cast to_timestamp input to
StringArray".to_string(),
- )
- })?;
-
- let result = (0..num_rows)
- .map(|i| {
- if string_args.is_null(i) {
- // NB: Since we use the same null bitset as the input,
- // the output for this value will be ignored, but we
- // need some value in the array we are building.
- Ok(0)
- } else {
- string_to_timestamp_nanos(string_args.value(i))
+pub(crate) fn unary_string_to_primitive_function<'a, T, O, F>(
+ args: &[&'a dyn Array],
+ op: F,
+ name: &str,
+) -> Result<PrimitiveArray<O>>
+where
+ O: ArrowPrimitiveType,
+ T: StringOffsetSizeTrait,
+ F: Fn(&'a str) -> Result<O::Native>,
+{
+ if args.len() != 1 {
+ return Err(DataFusionError::Internal(format!(
+ "{:?} args were supplied but {} takes exactly one argument",
+ args.len(),
+ name,
+ )));
+ }
+
+ let array = args[0]
+ .as_any()
+ .downcast_ref::<GenericStringArray<T>>()
+ .unwrap();
+
+ // first map is the iterator, second is for the `Option<_>`
+ array.iter().map(|x| x.map(|x| op(x)).transpose()).collect()
+}
+
+fn handle<'a, O, F, S>(
Review comment:
makes sense -- thank you for the explanation
##########
File path: rust/datafusion/src/physical_plan/datetime_expressions.rs
##########
@@ -167,152 +175,166 @@ fn naive_datetime_to_timestamp(s: &str, datetime:
NaiveDateTime) -> Result<i64>
}
}
-/// convert an array of strings into `Timestamp(Nanosecond, None)`
-pub fn to_timestamp(args: &[ArrayRef]) -> Result<TimestampNanosecondArray> {
- let num_rows = args[0].len();
- let string_args =
- &args[0]
- .as_any()
- .downcast_ref::<StringArray>()
- .ok_or_else(|| {
- DataFusionError::Internal(
- "could not cast to_timestamp input to
StringArray".to_string(),
- )
- })?;
-
- let result = (0..num_rows)
- .map(|i| {
- if string_args.is_null(i) {
- // NB: Since we use the same null bitset as the input,
- // the output for this value will be ignored, but we
- // need some value in the array we are building.
- Ok(0)
- } else {
- string_to_timestamp_nanos(string_args.value(i))
+pub(crate) fn unary_string_to_primitive_function<'a, T, O, F>(
+ args: &[&'a dyn Array],
+ op: F,
+ name: &str,
+) -> Result<PrimitiveArray<O>>
+where
+ O: ArrowPrimitiveType,
+ T: StringOffsetSizeTrait,
+ F: Fn(&'a str) -> Result<O::Native>,
+{
+ if args.len() != 1 {
+ return Err(DataFusionError::Internal(format!(
+ "{:?} args were supplied but {} takes exactly one argument",
+ args.len(),
+ name,
+ )));
+ }
+
+ let array = args[0]
+ .as_any()
+ .downcast_ref::<GenericStringArray<T>>()
+ .unwrap();
+
+ // first map is the iterator, second is for the `Option<_>`
+ array.iter().map(|x| x.map(|x| op(x)).transpose()).collect()
+}
+
+fn handle<'a, O, F, S>(
Review comment:
makes sense -- thank you for the explanation
##########
File path: rust/datafusion/examples/simple_udf.rs
##########
@@ -58,77 +54,54 @@ fn create_context() -> Result<ExecutionContext> {
Ok(ctx)
}
-// a small utility function to compute pow(base, exponent)
-fn maybe_pow(base: &Option<f64>, exponent: &Option<f64>) -> Option<f64> {
- match (base, exponent) {
- // in arrow, any value can be null.
- // Here we decide to make our UDF to return null when either base or
exponent is null.
- (Some(base), Some(exponent)) => Some(base.powf(*exponent)),
- _ => None,
- }
-}
-
-fn pow_array(base: &dyn Array, exponent: &dyn Array) -> Result<ArrayRef> {
- // 1. cast both arguments to f64. These casts MUST be aligned with the
signature or this function panics!
- let base = base
- .as_any()
- .downcast_ref::<Float64Array>()
- .expect("cast failed");
- let exponent = exponent
- .as_any()
- .downcast_ref::<Float64Array>()
- .expect("cast failed");
-
- // this is guaranteed by DataFusion. We place it just to make it obvious.
- assert_eq!(exponent.len(), base.len());
-
- // 2. perform the computation
- let array = base
- .iter()
- .zip(exponent.iter())
- .map(|(base, exponent)| maybe_pow(&base, &exponent))
- .collect::<Float64Array>();
-
- // `Ok` because no error occurred during the calculation (we should add
one if exponent was [0, 1[ and the base < 0 because that panics!)
- // `Arc` because arrays are immutable, thread-safe, trait objects.
- Ok(Arc::new(array))
-}
-
/// In this example we will declare a single-type, single return type UDF that
exponentiates f64, a^b
#[tokio::main]
async fn main() -> Result<()> {
let mut ctx = create_context()?;
// First, declare the actual implementation of the calculation
- let pow: ScalarFunctionImplementation = Arc::new(|args: &[ColumnarValue]| {
- // in DataFusion, all `args` and output are `ColumnarValue`, an enum
of either a scalar or a dynamically-typed array.
- // we can cater for both, or document that the UDF only supports some
variants.
- // here we will assume that al
+ let pow = |args: &[ArrayRef]| {
+ // in DataFusion, all `args` and output are dynamically-typed arrays,
which means that we need to:
// 1. cast the values to the type we want
// 2. perform the computation for every element in the array (using a
loop or SIMD) and construct the result
// this is guaranteed by DataFusion based on the function's signature.
assert_eq!(args.len(), 2);
- let (base, exponent) = (&args[0], &args[1]);
-
- let result = match (base, exponent) {
- (
- ColumnarValue::Scalar(ScalarValue::Float64(base)),
- ColumnarValue::Scalar(ScalarValue::Float64(exponent)),
- ) => ColumnarValue::Scalar(ScalarValue::Float64(maybe_pow(base,
exponent))),
- (ColumnarValue::Array(base), ColumnarValue::Array(exponent)) => {
- let array = pow_array(base.as_ref(), exponent.as_ref())?;
- ColumnarValue::Array(array)
- }
- _ => {
- return Err(DataFusionError::Execution(
- "This UDF only supports f64".to_string(),
- ))
- }
- };
- Ok(result)
- });
+ // 1. cast both arguments to f64. These casts MUST be aligned with the
signature or this function panics!
+ let base = &args[0]
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .expect("cast failed");
+ let exponent = &args[1]
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .expect("cast failed");
+
+ // this is guaranteed by DataFusion. We place it just to make it
obvious.
+ assert_eq!(exponent.len(), base.len());
+
+ // 2. perform the computation
+ let array = base
+ .iter()
+ .zip(exponent.iter())
+ .map(|(base, exponent)| {
+ match (base, exponent) {
+ // in arrow, any value can be null.
+ // Here we decide to make our UDF to return null when
either base or exponent is null.
+ (Some(base), Some(exponent)) => Some(base.powf(exponent)),
+ _ => None,
+ }
+ })
+ .collect::<Float64Array>();
+
+ // `Ok` because no error occurred during the calculation (we should
add one if exponent was [0, 1[ and the base < 0 because that panics!)
+ // `Arc` because arrays are immutable, thread-safe, trait objects.
+ Ok(Arc::new(array) as ArrayRef)
+ };
+ // the function above expects an `ArrayRef`, but DataFusion may pass a
scalar to a UDF.
+ // thus, we use `make_scalar_function` to decorare the closure so that it
can handle both Arrays and Scalar values.
Review comment:
👍
##########
File path: rust/datafusion/examples/simple_udf.rs
##########
@@ -54,50 +58,76 @@ fn create_context() -> Result<ExecutionContext> {
Ok(ctx)
}
+// a small utility function to compute pow(base, exponent)
+fn maybe_pow(base: &Option<f64>, exponent: &Option<f64>) -> Option<f64> {
+ match (base, exponent) {
+ // in arrow, any value can be null.
+ // Here we decide to make our UDF to return null when either base or
exponent is null.
+ (Some(base), Some(exponent)) => Some(base.powf(*exponent)),
+ _ => None,
+ }
+}
+
+fn pow_array(base: &dyn Array, exponent: &dyn Array) -> Result<ArrayRef> {
+ // 1. cast both arguments to f64. These casts MUST be aligned with the
signature or this function panics!
+ let base = base
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .expect("cast failed");
+ let exponent = exponent
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .expect("cast failed");
+
+ // this is guaranteed by DataFusion. We place it just to make it obvious.
+ assert_eq!(exponent.len(), base.len());
+
+ // 2. perform the computation
+ let array = base
+ .iter()
+ .zip(exponent.iter())
+ .map(|(base, exponent)| maybe_pow(&base, &exponent))
+ .collect::<Float64Array>();
+
+ // `Ok` because no error occurred during the calculation (we should add
one if exponent was [0, 1[ and the base < 0 because that panics!)
Review comment:
cool -- sorry, I guess I am used to seeing a half-open interval written with a `)`
-- so in this case something like `[0, 1)` to represent `0 <= exponent < 1`
(e.g. [here](https://en.wikipedia.org/wiki/Interval_(mathematics)#Terminology)).
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]