This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 4d5639e2a4 [Functions] Support Arithmetic function COT() (#6925)
4d5639e2a4 is described below
commit 4d5639e2a46fa6c2a74e4170f576cd70402bb830
Author: Syleechan <[email protected]>
AuthorDate: Fri Jul 14 16:27:19 2023 +0800
[Functions] Support Arithmetic function COT() (#6925)
---
datafusion/expr/src/built_in_function.rs | 10 ++-
datafusion/expr/src/expr_fn.rs | 2 +
datafusion/physical-expr/src/functions.rs | 3 +
datafusion/physical-expr/src/math_expressions.rs | 79 ++++++++++++++++++++++++
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/generated/pbjson.rs | 3 +
datafusion/proto/src/generated/prost.rs | 3 +
datafusion/proto/src/logical_plan/from_proto.rs | 6 +-
datafusion/proto/src/logical_plan/to_proto.rs | 1 +
9 files changed, 104 insertions(+), 4 deletions(-)
diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index 9acb82d47b..74561d9fd7 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -111,6 +111,8 @@ pub enum BuiltinScalarFunction {
Tanh,
/// trunc
Trunc,
+ /// cot
+ Cot,
// array functions
/// array_append
@@ -322,6 +324,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Sinh => Volatility::Immutable,
BuiltinScalarFunction::Sqrt => Volatility::Immutable,
BuiltinScalarFunction::Cbrt => Volatility::Immutable,
+ BuiltinScalarFunction::Cot => Volatility::Immutable,
BuiltinScalarFunction::Tan => Volatility::Immutable,
BuiltinScalarFunction::Tanh => Volatility::Immutable,
BuiltinScalarFunction::Trunc => Volatility::Immutable,
@@ -764,7 +767,8 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Cbrt
| BuiltinScalarFunction::Tan
| BuiltinScalarFunction::Tanh
- | BuiltinScalarFunction::Trunc => match input_expr_types[0] {
+ | BuiltinScalarFunction::Trunc
+ | BuiltinScalarFunction::Cot => match input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
},
@@ -1112,7 +1116,8 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Sqrt
| BuiltinScalarFunction::Tan
| BuiltinScalarFunction::Tanh
- | BuiltinScalarFunction::Trunc => {
+ | BuiltinScalarFunction::Trunc
+ | BuiltinScalarFunction::Cot => {
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR
(real numbers) and thus we
// return the best approximation for it (in f64).
@@ -1142,6 +1147,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
BuiltinScalarFunction::Cbrt => &["cbrt"],
BuiltinScalarFunction::Ceil => &["ceil"],
BuiltinScalarFunction::Cos => &["cos"],
+ BuiltinScalarFunction::Cot => &["cot"],
BuiltinScalarFunction::Cosh => &["cosh"],
BuiltinScalarFunction::Degrees => &["degrees"],
BuiltinScalarFunction::Exp => &["exp"],
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 480ea5d608..4773737674 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -476,6 +476,7 @@ scalar_expr!(Cbrt, cbrt, num, "cube root of a number");
scalar_expr!(Sin, sin, num, "sine");
scalar_expr!(Cos, cos, num, "cosine");
scalar_expr!(Tan, tan, num, "tangent");
+scalar_expr!(Cot, cot, num, "cotangent");
scalar_expr!(Sinh, sinh, num, "hyperbolic sine");
scalar_expr!(Cosh, cosh, num, "hyperbolic cosine");
scalar_expr!(Tanh, tanh, num, "hyperbolic tangent");
@@ -912,6 +913,7 @@ mod test {
test_unary_scalar_expr!(Sin, sin);
test_unary_scalar_expr!(Cos, cos);
test_unary_scalar_expr!(Tan, tan);
+ test_unary_scalar_expr!(Cot, cot);
test_unary_scalar_expr!(Sinh, sinh);
test_unary_scalar_expr!(Cosh, cosh);
test_unary_scalar_expr!(Tanh, tanh);
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 6a81042b7c..a92d4335d4 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -403,6 +403,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Log => {
Arc::new(|args| make_scalar_function(math_expressions::log)(args))
}
+ BuiltinScalarFunction::Cot => {
+ Arc::new(|args| make_scalar_function(math_expressions::cot)(args))
+ }
// array functions
BuiltinScalarFunction::ArrayAppend => {
diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs
index fbfb82814e..9a4653c8a0 100644
--- a/datafusion/physical-expr/src/math_expressions.rs
+++ b/datafusion/physical-expr/src/math_expressions.rs
@@ -497,6 +497,39 @@ pub fn log(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}
+///cot SQL function
+pub fn cot(args: &[ArrayRef]) -> Result<ArrayRef> {
+ match args[0].data_type() {
+ DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs!(
+ &args[0],
+ "x",
+ Float64Array,
+ { compute_cot64 }
+ )) as ArrayRef),
+
+ DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs!(
+ &args[0],
+ "x",
+ Float32Array,
+ { compute_cot32 }
+ )) as ArrayRef),
+
+ other => Err(DataFusionError::Internal(format!(
+ "Unsupported data type {other:?} for function cot"
+ ))),
+ }
+}
+
+fn compute_cot32(x: f32) -> f32 {
+ let a = f32::tan(x);
+ 1.0 / a
+}
+
+fn compute_cot64(x: f64) -> f64 {
+ let a = f64::tan(x);
+ 1.0 / a
+}
+
#[cfg(test)]
mod tests {
@@ -739,4 +772,50 @@ mod tests {
assert_eq!(ints.value(2), 75);
assert_eq!(ints.value(3), 16);
}
+
+ #[test]
+ fn test_cot_f32() {
+ let args: Vec<ArrayRef> =
+ vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))];
+ let result = cot(&args).expect("failed to initialize function cot");
+ let floats =
+ as_float32_array(&result).expect("failed to initialize function
cot");
+
+ let expected = Float32Array::from(vec![
+ -1.986_460_4,
+ -0.156_119_96,
+ -0.501_202_8,
+ 0.156_119_96,
+ ]);
+
+ let eps = 1e-6;
+ assert_eq!(floats.len(), 4);
+ assert!((floats.value(0) - expected.value(0)).abs() < eps);
+ assert!((floats.value(1) - expected.value(1)).abs() < eps);
+ assert!((floats.value(2) - expected.value(2)).abs() < eps);
+ assert!((floats.value(3) - expected.value(3)).abs() < eps);
+ }
+
+ #[test]
+ fn test_cot_f64() {
+ let args: Vec<ArrayRef> =
+ vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))];
+ let result = cot(&args).expect("failed to initialize function cot");
+ let floats =
+ as_float64_array(&result).expect("failed to initialize function
cot");
+
+ let expected = Float64Array::from(vec![
+ -1.986_458_685_881_4,
+ -0.156_119_952_161_6,
+ -0.501_202_783_380_1,
+ 0.156_119_952_161_6,
+ ]);
+
+ let eps = 1e-12;
+ assert_eq!(floats.len(), 4);
+ assert!((floats.value(0) - expected.value(0)).abs() < eps);
+ assert!((floats.value(1) - expected.value(1)).abs() < eps);
+ assert!((floats.value(2) - expected.value(2)).abs() < eps);
+ assert!((floats.value(3) - expected.value(3)).abs() < eps);
+ }
}
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 89bca57cf3..b26d18947f 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -566,6 +566,7 @@ enum ScalarFunction {
ArrayContains = 100;
Encode = 101;
Decode = 102;
+ Cot = 103;
}
message ScalarFunctionNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 590b462ad8..09e85f11a7 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -18064,6 +18064,7 @@ impl serde::Serialize for ScalarFunction {
Self::ArrayContains => "ArrayContains",
Self::Encode => "Encode",
Self::Decode => "Decode",
+ Self::Cot => "Cot",
};
serializer.serialize_str(variant)
}
@@ -18178,6 +18179,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"ArrayContains",
"Encode",
"Decode",
+ "Cot",
];
struct GeneratedVisitor;
@@ -18323,6 +18325,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
"ArrayContains" => Ok(ScalarFunction::ArrayContains),
"Encode" => Ok(ScalarFunction::Encode),
"Decode" => Ok(ScalarFunction::Decode),
+ "Cot" => Ok(ScalarFunction::Cot),
_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
}
}
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 251760f090..516aa325bc 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2265,6 +2265,7 @@ pub enum ScalarFunction {
ArrayContains = 100,
Encode = 101,
Decode = 102,
+ Cot = 103,
}
impl ScalarFunction {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -2376,6 +2377,7 @@ impl ScalarFunction {
ScalarFunction::ArrayContains => "ArrayContains",
ScalarFunction::Encode => "Encode",
ScalarFunction::Decode => "Decode",
+ ScalarFunction::Cot => "Cot",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -2484,6 +2486,7 @@ impl ScalarFunction {
"ArrayContains" => Some(Self::ArrayContains),
"Encode" => Some(Self::Encode),
"Decode" => Some(Self::Decode),
+ "Cot" => Some(Self::Cot),
_ => None,
}
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 1b48364ad4..8b70480e7d 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -40,8 +40,8 @@ use datafusion_expr::{
array_fill, array_length, array_ndims, array_position, array_positions,
array_prepend, array_remove, array_replace, array_to_string, ascii, asin,
asinh,
atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil,
character_length,
- chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, current_date,
current_time,
- date_bin, date_part, date_trunc, degrees, digest, exp,
+ chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date,
+ current_time, date_bin, date_part, date_trunc, degrees, digest, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
@@ -417,6 +417,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::Sin => Self::Sin,
ScalarFunction::Cos => Self::Cos,
ScalarFunction::Tan => Self::Tan,
+ ScalarFunction::Cot => Self::Cot,
ScalarFunction::Asin => Self::Asin,
ScalarFunction::Acos => Self::Acos,
ScalarFunction::Atan => Self::Atan,
@@ -1473,6 +1474,7 @@ pub fn parse_expr(
)),
ScalarFunction::CurrentDate => Ok(current_date()),
ScalarFunction::CurrentTime => Ok(current_time()),
+ ScalarFunction::Cot => Ok(cot(parse_expr(&args[0],
registry)?)),
_ => Err(proto_error(
"Protobuf deserialization error: Unsupported scalar
function",
)),
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 8665ca00c3..30bb30950c 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1344,6 +1344,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Sin => Self::Sin,
BuiltinScalarFunction::Cos => Self::Cos,
BuiltinScalarFunction::Tan => Self::Tan,
+ BuiltinScalarFunction::Cot => Self::Cot,
BuiltinScalarFunction::Sinh => Self::Sinh,
BuiltinScalarFunction::Cosh => Self::Cosh,
BuiltinScalarFunction::Tanh => Self::Tanh,