kumarUjjawal commented on code in PR #19888:
URL: https://github.com/apache/datafusion/pull/19888#discussion_r2708799997
##########
datafusion/functions/src/math/cot.rs:
##########
@@ -129,54 +152,93 @@ fn compute_cot64(x: f64) -> f64 {
#[cfg(test)]
mod test {
- use crate::math::cot::cot;
+ use std::sync::Arc;
+
use arrow::array::{ArrayRef, Float32Array, Float64Array};
+ use arrow::datatypes::{DataType, Field};
use datafusion_common::cast::{as_float32_array, as_float64_array};
- use std::sync::Arc;
+ use datafusion_common::config::ConfigOptions;
+ use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
+
+ use crate::math::cot::CotFunc;
#[test]
fn test_cot_f32() {
- let args: Vec<ArrayRef> =
- vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))];
- let result = cot(&args).expect("failed to initialize function cot");
- let floats =
- as_float32_array(&result).expect("failed to initialize function cot");
-
- let expected = Float32Array::from(vec![
- -1.986_460_4,
- -0.156_119_96,
- -0.501_202_8,
- 0.156_119_96,
- ]);
-
- let eps = 1e-6;
- assert_eq!(floats.len(), 4);
- assert!((floats.value(0) - expected.value(0)).abs() < eps);
- assert!((floats.value(1) - expected.value(1)).abs() < eps);
- assert!((floats.value(2) - expected.value(2)).abs() < eps);
- assert!((floats.value(3) - expected.value(3)).abs() < eps);
+ let array = Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]));
+ let arg_fields = vec![Field::new("a", DataType::Float32, false).into()];
+ let args = ScalarFunctionArgs {
+ args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)],
+ arg_fields,
+ number_rows: array.len(),
+ return_field: Field::new("f", DataType::Float32, true).into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ };
+ let result = CotFunc::new()
+ .invoke_with_args(args)
+ .expect("failed to initialize function cot");
+
+ match result {
+ ColumnarValue::Array(arr) => {
+ let floats = as_float32_array(&arr)
+ .expect("failed to convert result to a Float32Array");
+
+ let expected = Float32Array::from(vec![
+ -1.986_460_4,
+ -0.156_119_96,
+ -0.501_202_8,
+ 0.156_119_96,
+ ]);
+
+ let eps = 1e-6;
+ assert_eq!(floats.len(), 4);
+ assert!((floats.value(0) - expected.value(0)).abs() < eps);
+ assert!((floats.value(1) - expected.value(1)).abs() < eps);
+ assert!((floats.value(2) - expected.value(2)).abs() < eps);
+ assert!((floats.value(3) - expected.value(3)).abs() < eps);
+ }
+ ColumnarValue::Scalar(_) => {
+ panic!("Expected an array value")
+ }
+ }
}
#[test]
fn test_cot_f64() {
- let args: Vec<ArrayRef> =
- vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))];
- let result = cot(&args).expect("failed to initialize function cot");
- let floats =
- as_float64_array(&result).expect("failed to initialize function cot");
-
- let expected = Float64Array::from(vec![
- -1.986_458_685_881_4,
- -0.156_119_952_161_6,
- -0.501_202_783_380_1,
- 0.156_119_952_161_6,
- ]);
-
- let eps = 1e-12;
- assert_eq!(floats.len(), 4);
- assert!((floats.value(0) - expected.value(0)).abs() < eps);
- assert!((floats.value(1) - expected.value(1)).abs() < eps);
- assert!((floats.value(2) - expected.value(2)).abs() < eps);
- assert!((floats.value(3) - expected.value(3)).abs() < eps);
+ let array = Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]));
+ let arg_fields = vec![Field::new("a", DataType::Float64, false).into()];
+ let args = ScalarFunctionArgs {
+ args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)],
+ arg_fields,
+ number_rows: array.len(),
+ return_field: Field::new("f", DataType::Float64, true).into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ };
+ let result = CotFunc::new()
+ .invoke_with_args(args)
+ .expect("failed to initialize function cot");
+
+ match result {
+ ColumnarValue::Array(arr) => {
+ let floats = as_float64_array(&arr)
+ .expect("failed to convert result to a Float64Array");
+
+ let expected = Float64Array::from(vec![
+ -1.986_458_685_881_4,
+ -0.156_119_952_161_6,
+ -0.501_202_783_380_1,
+ 0.156_119_952_161_6,
+ ]);
+
+ let eps = 1e-12;
+ assert_eq!(floats.len(), 4);
+ assert!((floats.value(0) - expected.value(0)).abs() < eps);
+ assert!((floats.value(1) - expected.value(1)).abs() < eps);
+ assert!((floats.value(2) - expected.value(2)).abs() < eps);
+ assert!((floats.value(3) - expected.value(3)).abs() < eps);
+ }
+ ColumnarValue::Scalar(_) => {
+ panic!("Expected an array value")
+ }
+ }
Review Comment:
Added unit tests for these. Thanks for the feedback.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]