This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d5639e2a4 [Functions] Support Arithmetic function COT() (#6925)
4d5639e2a4 is described below

commit 4d5639e2a46fa6c2a74e4170f576cd70402bb830
Author: Syleechan <[email protected]>
AuthorDate: Fri Jul 14 16:27:19 2023 +0800

    [Functions] Support Arithmetic function COT() (#6925)
---
 datafusion/expr/src/built_in_function.rs         | 10 ++-
 datafusion/expr/src/expr_fn.rs                   |  2 +
 datafusion/physical-expr/src/functions.rs        |  3 +
 datafusion/physical-expr/src/math_expressions.rs | 79 ++++++++++++++++++++++++
 datafusion/proto/proto/datafusion.proto          |  1 +
 datafusion/proto/src/generated/pbjson.rs         |  3 +
 datafusion/proto/src/generated/prost.rs          |  3 +
 datafusion/proto/src/logical_plan/from_proto.rs  |  6 +-
 datafusion/proto/src/logical_plan/to_proto.rs    |  1 +
 9 files changed, 104 insertions(+), 4 deletions(-)

diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index 9acb82d47b..74561d9fd7 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -111,6 +111,8 @@ pub enum BuiltinScalarFunction {
     Tanh,
     /// trunc
     Trunc,
+    /// cot
+    Cot,
 
     // array functions
     /// array_append
@@ -322,6 +324,7 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::Sinh => Volatility::Immutable,
             BuiltinScalarFunction::Sqrt => Volatility::Immutable,
             BuiltinScalarFunction::Cbrt => Volatility::Immutable,
+            BuiltinScalarFunction::Cot => Volatility::Immutable,
             BuiltinScalarFunction::Tan => Volatility::Immutable,
             BuiltinScalarFunction::Tanh => Volatility::Immutable,
             BuiltinScalarFunction::Trunc => Volatility::Immutable,
@@ -764,7 +767,8 @@ impl BuiltinScalarFunction {
             | BuiltinScalarFunction::Cbrt
             | BuiltinScalarFunction::Tan
             | BuiltinScalarFunction::Tanh
-            | BuiltinScalarFunction::Trunc => match input_expr_types[0] {
+            | BuiltinScalarFunction::Trunc
+            | BuiltinScalarFunction::Cot => match input_expr_types[0] {
                 Float32 => Ok(Float32),
                 _ => Ok(Float64),
             },
@@ -1112,7 +1116,8 @@ impl BuiltinScalarFunction {
             | BuiltinScalarFunction::Sqrt
             | BuiltinScalarFunction::Tan
             | BuiltinScalarFunction::Tanh
-            | BuiltinScalarFunction::Trunc => {
+            | BuiltinScalarFunction::Trunc
+            | BuiltinScalarFunction::Cot => {
                 // math expressions expect 1 argument of type f64 or f32
                 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
                 // return the best approximation for it (in f64).
@@ -1142,6 +1147,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
         BuiltinScalarFunction::Cbrt => &["cbrt"],
         BuiltinScalarFunction::Ceil => &["ceil"],
         BuiltinScalarFunction::Cos => &["cos"],
+        BuiltinScalarFunction::Cot => &["cot"],
         BuiltinScalarFunction::Cosh => &["cosh"],
         BuiltinScalarFunction::Degrees => &["degrees"],
         BuiltinScalarFunction::Exp => &["exp"],
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 480ea5d608..4773737674 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -476,6 +476,7 @@ scalar_expr!(Cbrt, cbrt, num, "cube root of a number");
 scalar_expr!(Sin, sin, num, "sine");
 scalar_expr!(Cos, cos, num, "cosine");
 scalar_expr!(Tan, tan, num, "tangent");
+scalar_expr!(Cot, cot, num, "cotangent");
 scalar_expr!(Sinh, sinh, num, "hyperbolic sine");
 scalar_expr!(Cosh, cosh, num, "hyperbolic cosine");
 scalar_expr!(Tanh, tanh, num, "hyperbolic tangent");
@@ -912,6 +913,7 @@ mod test {
         test_unary_scalar_expr!(Sin, sin);
         test_unary_scalar_expr!(Cos, cos);
         test_unary_scalar_expr!(Tan, tan);
+        test_unary_scalar_expr!(Cot, cot);
         test_unary_scalar_expr!(Sinh, sinh);
         test_unary_scalar_expr!(Cosh, cosh);
         test_unary_scalar_expr!(Tanh, tanh);
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 6a81042b7c..a92d4335d4 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -403,6 +403,9 @@ pub fn create_physical_fun(
         BuiltinScalarFunction::Log => {
             Arc::new(|args| make_scalar_function(math_expressions::log)(args))
         }
+        BuiltinScalarFunction::Cot => {
+            Arc::new(|args| make_scalar_function(math_expressions::cot)(args))
+        }
 
         // array functions
         BuiltinScalarFunction::ArrayAppend => {
diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs
index fbfb82814e..9a4653c8a0 100644
--- a/datafusion/physical-expr/src/math_expressions.rs
+++ b/datafusion/physical-expr/src/math_expressions.rs
@@ -497,6 +497,39 @@ pub fn log(args: &[ArrayRef]) -> Result<ArrayRef> {
     }
 }
 
+/// cot SQL function: dispatches on the data type of `args[0]` (Float64 or Float32) and returns an error for any other type.
+pub fn cot(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args[0].data_type() { // only the first argument is inspected; assumes callers always pass at least one array — TODO confirm
+        DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs!(
+            &args[0],
+            "x",
+            Float64Array,
+            { compute_cot64 } // f64 kernel: 1.0 / tan(x)
+        )) as ArrayRef),
+
+        DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs!(
+            &args[0],
+            "x",
+            Float32Array,
+            { compute_cot32 } // f32 kernel: 1.0 / tan(x)
+        )) as ArrayRef),
+
+        other => Err(DataFusionError::Internal(format!( // every non-Float64/Float32 type lands here
+            "Unsupported data type {other:?} for function cot"
+        ))),
+    }
+}
+
+fn compute_cot32(x: f32) -> f32 { // cotangent of `x` (radians, as taken by f32::tan), computed as the reciprocal of the tangent
+    let a = f32::tan(x);
+    1.0 / a // no guard for tan(x) == 0.0: the division result is returned as-is
+}
+
+fn compute_cot64(x: f64) -> f64 { // cotangent of `x` (radians, as taken by f64::tan), computed as the reciprocal of the tangent
+    let a = f64::tan(x);
+    1.0 / a // no guard for tan(x) == 0.0: the division result is returned as-is
+}
+
 #[cfg(test)]
 mod tests {
 
@@ -739,4 +772,50 @@ mod tests {
         assert_eq!(ints.value(2), 75);
         assert_eq!(ints.value(3), 16);
     }
+
+    #[test]
+    fn test_cot_f32() { // checks cot() on a Float32Array against precomputed values, element by element
+        let args: Vec<ArrayRef> =
+            vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; // inputs are radians: the expected values below match 1/tan(x) in radians
+        let result = cot(&args).expect("failed to initialize function cot");
+        let floats =
+            as_float32_array(&result).expect("failed to initialize function cot");
+
+        let expected = Float32Array::from(vec![
+            -1.986_460_4,
+            -0.156_119_96,
+            -0.501_202_8,
+            0.156_119_96,
+        ]);
+
+        let eps = 1e-6; // absolute tolerance for the f32 comparison
+        assert_eq!(floats.len(), 4);
+        assert!((floats.value(0) - expected.value(0)).abs() < eps);
+        assert!((floats.value(1) - expected.value(1)).abs() < eps);
+        assert!((floats.value(2) - expected.value(2)).abs() < eps);
+        assert!((floats.value(3) - expected.value(3)).abs() < eps);
+    }
+
+    #[test]
+    fn test_cot_f64() { // checks cot() on a Float64Array against precomputed values, element by element
+        let args: Vec<ArrayRef> =
+            vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; // inputs are radians: the expected values below match 1/tan(x) in radians
+        let result = cot(&args).expect("failed to initialize function cot");
+        let floats =
+            as_float64_array(&result).expect("failed to initialize function cot");
+
+        let expected = Float64Array::from(vec![
+            -1.986_458_685_881_4,
+            -0.156_119_952_161_6,
+            -0.501_202_783_380_1,
+            0.156_119_952_161_6,
+        ]);
+
+        let eps = 1e-12; // absolute tolerance for the f64 comparison
+        assert_eq!(floats.len(), 4);
+        assert!((floats.value(0) - expected.value(0)).abs() < eps);
+        assert!((floats.value(1) - expected.value(1)).abs() < eps);
+        assert!((floats.value(2) - expected.value(2)).abs() < eps);
+        assert!((floats.value(3) - expected.value(3)).abs() < eps);
+    }
 }
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 89bca57cf3..b26d18947f 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -566,6 +566,7 @@ enum ScalarFunction {
   ArrayContains = 100;
   Encode = 101;
   Decode = 102;
+  Cot = 103;
 }
 
 message ScalarFunctionNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 590b462ad8..09e85f11a7 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -18064,6 +18064,7 @@ impl serde::Serialize for ScalarFunction {
             Self::ArrayContains => "ArrayContains",
             Self::Encode => "Encode",
             Self::Decode => "Decode",
+            Self::Cot => "Cot",
         };
         serializer.serialize_str(variant)
     }
@@ -18178,6 +18179,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
             "ArrayContains",
             "Encode",
             "Decode",
+            "Cot",
         ];
 
         struct GeneratedVisitor;
@@ -18323,6 +18325,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
                     "ArrayContains" => Ok(ScalarFunction::ArrayContains),
                     "Encode" => Ok(ScalarFunction::Encode),
                     "Decode" => Ok(ScalarFunction::Decode),
+                    "Cot" => Ok(ScalarFunction::Cot),
                     _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
                 }
             }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 251760f090..516aa325bc 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2265,6 +2265,7 @@ pub enum ScalarFunction {
     ArrayContains = 100,
     Encode = 101,
     Decode = 102,
+    Cot = 103,
 }
 impl ScalarFunction {
     /// String value of the enum field names used in the ProtoBuf definition.
@@ -2376,6 +2377,7 @@ impl ScalarFunction {
             ScalarFunction::ArrayContains => "ArrayContains",
             ScalarFunction::Encode => "Encode",
             ScalarFunction::Decode => "Decode",
+            ScalarFunction::Cot => "Cot",
         }
     }
     /// Creates an enum from field names used in the ProtoBuf definition.
@@ -2484,6 +2486,7 @@ impl ScalarFunction {
             "ArrayContains" => Some(Self::ArrayContains),
             "Encode" => Some(Self::Encode),
             "Decode" => Some(Self::Decode),
+            "Cot" => Some(Self::Cot),
             _ => None,
         }
     }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 1b48364ad4..8b70480e7d 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -40,8 +40,8 @@ use datafusion_expr::{
     array_fill, array_length, array_ndims, array_position, array_positions,
     array_prepend, array_remove, array_replace, array_to_string, ascii, asin, asinh,
     atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length,
-    chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, current_date, current_time,
-    date_bin, date_part, date_trunc, degrees, digest, exp,
+    chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date,
+    current_time, date_bin, date_part, date_trunc, degrees, digest, exp,
     expr::{self, InList, Sort, WindowFunction},
     factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2,
     logical_plan::{PlanType, StringifiedPlan},
@@ -417,6 +417,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
             ScalarFunction::Sin => Self::Sin,
             ScalarFunction::Cos => Self::Cos,
             ScalarFunction::Tan => Self::Tan,
+            ScalarFunction::Cot => Self::Cot,
             ScalarFunction::Asin => Self::Asin,
             ScalarFunction::Acos => Self::Acos,
             ScalarFunction::Atan => Self::Atan,
@@ -1473,6 +1474,7 @@ pub fn parse_expr(
                 )),
                 ScalarFunction::CurrentDate => Ok(current_date()),
                 ScalarFunction::CurrentTime => Ok(current_time()),
+                ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry)?)),
                 _ => Err(proto_error(
                     "Protobuf deserialization error: Unsupported scalar function",
                 )),
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 8665ca00c3..30bb30950c 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1344,6 +1344,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
             BuiltinScalarFunction::Sin => Self::Sin,
             BuiltinScalarFunction::Cos => Self::Cos,
             BuiltinScalarFunction::Tan => Self::Tan,
+            BuiltinScalarFunction::Cot => Self::Cot,
             BuiltinScalarFunction::Sinh => Self::Sinh,
             BuiltinScalarFunction::Cosh => Self::Cosh,
             BuiltinScalarFunction::Tanh => Self::Tanh,

Reply via email to