martin-g commented on code in PR #19888:
URL: https://github.com/apache/datafusion/pull/19888#discussion_r2707890087


##########
datafusion/functions/benches/cot.rs:
##########
@@ -85,6 +86,51 @@ fn criterion_benchmark(c: &mut Criterion) {
                 )
             })
         });
+
+        // Scalar benchmarks
+        let scalar_f32_args =
+            vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0)))];
+        let scalar_f32_arg_fields =
+            vec![Field::new("a", DataType::Float32, false).into()];
+        let return_field_f32 = Field::new("f", DataType::Float32, 
false).into();
+
+        c.bench_function(&format!("cot f32 scalar: {size}"), |b| {
+            b.iter(|| {
+                black_box(
+                    cot_fn
+                        .invoke_with_args(ScalarFunctionArgs {
+                            args: scalar_f32_args.clone(),
+                            arg_fields: scalar_f32_arg_fields.clone(),
+                            number_rows: 1,

Review Comment:
   ```suggestion
                               number_rows: size,
   ```
   Currently the `size` variable is used only in the benchmark label; the arguments passed to the function are always the same, regardless of `size`.



##########
datafusion/functions/src/math/cot.rs:
##########
@@ -129,54 +152,93 @@ fn compute_cot64(x: f64) -> f64 {
 
 #[cfg(test)]
 mod test {
-    use crate::math::cot::cot;
+    use std::sync::Arc;
+
     use arrow::array::{ArrayRef, Float32Array, Float64Array};
+    use arrow::datatypes::{DataType, Field};
     use datafusion_common::cast::{as_float32_array, as_float64_array};
-    use std::sync::Arc;
+    use datafusion_common::config::ConfigOptions;
+    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
+
+    use crate::math::cot::CotFunc;
 
     #[test]
     fn test_cot_f32() {
-        let args: Vec<ArrayRef> =
-            vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))];
-        let result = cot(&args).expect("failed to initialize function cot");
-        let floats =
-            as_float32_array(&result).expect("failed to initialize function 
cot");
-
-        let expected = Float32Array::from(vec![
-            -1.986_460_4,
-            -0.156_119_96,
-            -0.501_202_8,
-            0.156_119_96,
-        ]);
-
-        let eps = 1e-6;
-        assert_eq!(floats.len(), 4);
-        assert!((floats.value(0) - expected.value(0)).abs() < eps);
-        assert!((floats.value(1) - expected.value(1)).abs() < eps);
-        assert!((floats.value(2) - expected.value(2)).abs() < eps);
-        assert!((floats.value(3) - expected.value(3)).abs() < eps);
+        let array = Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, 
-30.0]));
+        let arg_fields = vec![Field::new("a", DataType::Float32, 
false).into()];
+        let args = ScalarFunctionArgs {
+            args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)],
+            arg_fields,
+            number_rows: array.len(),
+            return_field: Field::new("f", DataType::Float32, true).into(),
+            config_options: Arc::new(ConfigOptions::default()),
+        };
+        let result = CotFunc::new()
+            .invoke_with_args(args)
+            .expect("failed to initialize function cot");
+
+        match result {
+            ColumnarValue::Array(arr) => {
+                let floats = as_float32_array(&arr)
+                    .expect("failed to convert result to a Float32Array");
+
+                let expected = Float32Array::from(vec![
+                    -1.986_460_4,
+                    -0.156_119_96,
+                    -0.501_202_8,
+                    0.156_119_96,
+                ]);
+
+                let eps = 1e-6;
+                assert_eq!(floats.len(), 4);
+                assert!((floats.value(0) - expected.value(0)).abs() < eps);
+                assert!((floats.value(1) - expected.value(1)).abs() < eps);
+                assert!((floats.value(2) - expected.value(2)).abs() < eps);
+                assert!((floats.value(3) - expected.value(3)).abs() < eps);
+            }
+            ColumnarValue::Scalar(_) => {
+                panic!("Expected an array value")
+            }
+        }
     }
 
     #[test]
     fn test_cot_f64() {
-        let args: Vec<ArrayRef> =
-            vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))];
-        let result = cot(&args).expect("failed to initialize function cot");
-        let floats =
-            as_float64_array(&result).expect("failed to initialize function 
cot");
-
-        let expected = Float64Array::from(vec![
-            -1.986_458_685_881_4,
-            -0.156_119_952_161_6,
-            -0.501_202_783_380_1,
-            0.156_119_952_161_6,
-        ]);
-
-        let eps = 1e-12;
-        assert_eq!(floats.len(), 4);
-        assert!((floats.value(0) - expected.value(0)).abs() < eps);
-        assert!((floats.value(1) - expected.value(1)).abs() < eps);
-        assert!((floats.value(2) - expected.value(2)).abs() < eps);
-        assert!((floats.value(3) - expected.value(3)).abs() < eps);
+        let array = Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, 
-30.0]));
+        let arg_fields = vec![Field::new("a", DataType::Float64, 
false).into()];
+        let args = ScalarFunctionArgs {
+            args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)],
+            arg_fields,
+            number_rows: array.len(),
+            return_field: Field::new("f", DataType::Float64, true).into(),
+            config_options: Arc::new(ConfigOptions::default()),
+        };
+        let result = CotFunc::new()
+            .invoke_with_args(args)
+            .expect("failed to initialize function cot");
+
+        match result {
+            ColumnarValue::Array(arr) => {
+                let floats = as_float64_array(&arr)
+                    .expect("failed to convert result to a Float64Array");
+
+                let expected = Float64Array::from(vec![
+                    -1.986_458_685_881_4,
+                    -0.156_119_952_161_6,
+                    -0.501_202_783_380_1,
+                    0.156_119_952_161_6,
+                ]);
+
+                let eps = 1e-12;
+                assert_eq!(floats.len(), 4);
+                assert!((floats.value(0) - expected.value(0)).abs() < eps);
+                assert!((floats.value(1) - expected.value(1)).abs() < eps);
+                assert!((floats.value(2) - expected.value(2)).abs() < eps);
+                assert!((floats.value(3) - expected.value(3)).abs() < eps);
+            }
+            ColumnarValue::Scalar(_) => {
+                panic!("Expected an array value")
+            }
+        }

Review Comment:
   There are no tests for the Scalar input/output (the fast path).
   It would also be good to add tests for inputs such as `NULL`, `0.0` and `std::f64::consts::PI`.



##########
datafusion/functions/src/math/cot.rs:
##########
@@ -96,24 +96,47 @@ impl ScalarUDFImpl for CotFunc {
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
-        make_scalar_function(cot, vec![])(&args.args)
-    }
-}
-
-///cot SQL function
-fn cot(args: &[ArrayRef]) -> Result<ArrayRef> {
-    match args[0].data_type() {
-        Float64 => Ok(Arc::new(
-            args[0]
-                .as_primitive::<Float64Type>()
-                .unary::<_, Float64Type>(|x: f64| compute_cot64(x)),
-        ) as ArrayRef),
-        Float32 => Ok(Arc::new(
-            args[0]
-                .as_primitive::<Float32Type>()
-                .unary::<_, Float32Type>(|x: f32| compute_cot32(x)),
-        ) as ArrayRef),
-        other => exec_err!("Unsupported data type {other:?} for function cot"),
+        let return_type = args.return_type().clone();

Review Comment:
   This variable is used only once; it could be moved inside the `if scalar.is_null() {` branch so the clone is only made when it is actually needed.



##########
datafusion/functions/benches/cot.rs:
##########
@@ -85,6 +86,51 @@ fn criterion_benchmark(c: &mut Criterion) {
                 )
             })
         });
+
+        // Scalar benchmarks
+        let scalar_f32_args =
+            vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0)))];
+        let scalar_f32_arg_fields =
+            vec![Field::new("a", DataType::Float32, false).into()];
+        let return_field_f32 = Field::new("f", DataType::Float32, 
false).into();
+
+        c.bench_function(&format!("cot f32 scalar: {size}"), |b| {
+            b.iter(|| {
+                black_box(
+                    cot_fn
+                        .invoke_with_args(ScalarFunctionArgs {
+                            args: scalar_f32_args.clone(),
+                            arg_fields: scalar_f32_arg_fields.clone(),
+                            number_rows: 1,
+                            return_field: Arc::clone(&return_field_f32),
+                            config_options: Arc::clone(&config_options),
+                        })
+                        .unwrap(),
+                )
+            })
+        });
+
+        let scalar_f64_args =
+            vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))];
+        let scalar_f64_arg_fields =
+            vec![Field::new("a", DataType::Float64, false).into()];
+        let return_field_f64 = Field::new("f", DataType::Float64, 
false).into();
+
+        c.bench_function(&format!("cot f64 scalar: {size}"), |b| {
+            b.iter(|| {
+                black_box(
+                    cot_fn
+                        .invoke_with_args(ScalarFunctionArgs {
+                            args: scalar_f64_args.clone(),
+                            arg_fields: scalar_f64_arg_fields.clone(),
+                            number_rows: 1,

Review Comment:
   ```suggestion
                               number_rows: size,
   ```



##########
datafusion/functions/src/math/cot.rs:
##########
@@ -96,24 +96,47 @@ impl ScalarUDFImpl for CotFunc {
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
-        make_scalar_function(cot, vec![])(&args.args)
-    }
-}
-
-///cot SQL function
-fn cot(args: &[ArrayRef]) -> Result<ArrayRef> {
-    match args[0].data_type() {
-        Float64 => Ok(Arc::new(
-            args[0]
-                .as_primitive::<Float64Type>()
-                .unary::<_, Float64Type>(|x: f64| compute_cot64(x)),
-        ) as ArrayRef),
-        Float32 => Ok(Arc::new(
-            args[0]
-                .as_primitive::<Float32Type>()
-                .unary::<_, Float32Type>(|x: f32| compute_cot32(x)),
-        ) as ArrayRef),
-        other => exec_err!("Unsupported data type {other:?} for function cot"),
+        let return_type = args.return_type().clone();
+        let [arg] = take_function_args(self.name(), args.args)?;
+
+        match arg {
+            ColumnarValue::Scalar(scalar) => {
+                if scalar.is_null() {
+                    return ColumnarValue::Scalar(ScalarValue::Null)
+                        .cast_to(&return_type, None);
+                }
+
+                match scalar {
+                    ScalarValue::Float64(Some(v)) => Ok(ColumnarValue::Scalar(
+                        ScalarValue::Float64(Some(compute_cot64(v))),
+                    )),
+                    ScalarValue::Float32(Some(v)) => Ok(ColumnarValue::Scalar(
+                        ScalarValue::Float32(Some(compute_cot32(v))),
+                    )),
+                    _ => {
+                        internal_err!(
+                            "Unexpected scalar type for cot: {:?}",
+                            scalar.data_type()
+                        )
+                    }
+                }
+            }
+            ColumnarValue::Array(array) => match array.data_type() {
+                Float64 => Ok(ColumnarValue::Array(Arc::new(
+                    array
+                        .as_primitive::<Float64Type>()
+                        .unary::<_, Float64Type>(compute_cot64),
+                ))),
+                Float32 => Ok(ColumnarValue::Array(Arc::new(
+                    array
+                        .as_primitive::<Float32Type>()
+                        .unary::<_, Float32Type>(compute_cot32),
+                ))),
+                other => {
+                    internal_err!("Unexpected data type {other:?} for function 
cot")

Review Comment:
   Is it intentional to use `internal_err!()` here instead of `exec_err!()`, which the old code used (old line 116)?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to