This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 5901df58b2 feat: add bounds for unary math scalar functions (#11584)
5901df58b2 is described below
commit 5901df58b21b8b4e36011744e7ddc17bcb6a37b3
Author: Trent Hauck <[email protected]>
AuthorDate: Wed Jul 24 12:21:13 2024 -0700
feat: add bounds for unary math scalar functions (#11584)
* feat: unary udf function bounds
* feat: add bounds for more types
* feat: remove eprint
* fix: add missing bounds file
* tests: add tests for unary udf bounds
* tests: test f32 and f64
* build: remove unrelated changes
* refactor: better unbounded func name
* tests: fix tests
* refactor: use data_type method
* refactor: add more useful intervals to Interval
* refactor: use typed bounds for (-inf, inf)
* refactor: inf to unbounded
* refactor: add lower/upper pi bounds
* refactor: consts to consts module
* fix: add missing file
* fix: docstring typo
* refactor: remove unused signum bounds
---
datafusion/common/src/scalar/consts.rs | 44 ++++
datafusion/common/src/scalar/mod.rs | 119 ++++++++++
datafusion/expr/src/interval_arithmetic.rs | 32 +++
datafusion/functions/src/macros.rs | 7 +-
datafusion/functions/src/math/bounds.rs | 108 +++++++++
datafusion/functions/src/math/mod.rs | 302 ++++++++++++++++++++++++--
datafusion/functions/src/math/monotonicity.rs | 17 +-
7 files changed, 595 insertions(+), 34 deletions(-)
diff --git a/datafusion/common/src/scalar/consts.rs
b/datafusion/common/src/scalar/consts.rs
new file mode 100644
index 0000000000..efcde65184
--- /dev/null
+++ b/datafusion/common/src/scalar/consts.rs
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Constants defined for scalar construction: PI-related endpoints whose literals sit slightly farther from zero than the value they bound.
+
+// Upper endpoint for PI in f32 (PI ~ 3.1415927); the allow silences clippy's approx_constant lint for this deliberate near-PI literal.
+#[allow(clippy::approx_constant)]
+pub(super) const PI_UPPER_F32: f32 = 3.141593_f32;
+
+// Upper endpoint for PI in f64 (PI ~ 3.141592653589793); literal is slightly above PI.
+pub(super) const PI_UPPER_F64: f64 = 3.141592653589794_f64;
+
+// Lower endpoint for -PI in f32 (-PI ~ -3.1415927); the allow silences clippy's approx_constant lint for this deliberate near-PI literal.
+#[allow(clippy::approx_constant)]
+pub(super) const NEGATIVE_PI_LOWER_F32: f32 = -3.141593_f32;
+
+// Lower endpoint for -PI in f64 (-PI ~ -3.141592653589793); literal is slightly below -PI.
+pub(super) const NEGATIVE_PI_LOWER_F64: f64 = -3.141592653589794_f64;
+
+// Upper endpoint for PI / 2 in f32 (PI / 2 ~ 1.5707964); literal is slightly above PI / 2.
+pub(super) const FRAC_PI_2_UPPER_F32: f32 = 1.5707965_f32;
+
+// Upper endpoint for PI / 2 in f64 (PI / 2 ~ 1.5707963267948966); literal is slightly above PI / 2.
+pub(super) const FRAC_PI_2_UPPER_F64: f64 = 1.5707963267948967_f64;
+
+// Lower endpoint for -PI / 2 in f32 (-PI / 2 ~ -1.5707964); literal is slightly below -PI / 2.
+pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F32: f32 = -1.5707965_f32;
+
+// Lower endpoint for -PI / 2 in f64 (-PI / 2 ~ -1.5707963267948966); literal is slightly below -PI / 2.
+pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F64: f64 = -1.5707963267948967_f64;
diff --git a/datafusion/common/src/scalar/mod.rs
b/datafusion/common/src/scalar/mod.rs
index 92ed897e71..286df339ad 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -17,7 +17,9 @@
//! [`ScalarValue`]: stores single values
+mod consts;
mod struct_builder;
+
use std::borrow::Borrow;
use std::cmp::Ordering;
use std::collections::{HashSet, VecDeque};
@@ -1007,6 +1009,123 @@ impl ScalarValue {
}
}
+    /// Builds a [`ScalarValue`] holding PI in the requested floating-point type.
+    ///
+    /// Only `Float32` and `Float64` are accepted; any other type yields an
+    /// internal error.
+    pub fn new_pi(datatype: &DataType) -> Result<ScalarValue> {
+        let pi = match datatype {
+            DataType::Float32 => ScalarValue::from(std::f32::consts::PI),
+            DataType::Float64 => ScalarValue::from(std::f64::consts::PI),
+            _ => return _internal_err!("PI is not supported for data type: {:?}", datatype),
+        };
+        Ok(pi)
+    }
+
+    /// Returns a [`ScalarValue`] representing PI's upper bound.
+    ///
+    /// The value comes from the `consts` module and is a literal slightly
+    /// above PI, so it can serve as an interval endpoint that does not
+    /// undershoot PI. Only `Float32` and `Float64` are accepted; any other
+    /// type yields an internal error.
+    pub fn new_pi_upper(datatype: &DataType) -> Result<ScalarValue> {
+        // TODO: replace the constants with next_up/next_down when
+        // they are stabilized: https://doc.rust-lang.org/std/primitive.f64.html#method.next_up
+        match datatype {
+            DataType::Float32 => Ok(ScalarValue::from(consts::PI_UPPER_F32)),
+            DataType::Float64 => Ok(ScalarValue::from(consts::PI_UPPER_F64)),
+            _ => {
+                // Keep the message on one line: a line break inside the
+                // literal would become part of the error text.
+                _internal_err!("PI_UPPER is not supported for data type: {:?}", datatype)
+            }
+        }
+    }
+
+    /// Returns a [`ScalarValue`] representing -PI's lower bound.
+    ///
+    /// The value comes from the `consts` module and is a literal slightly
+    /// below -PI. Only `Float32` and `Float64` are accepted; any other type
+    /// yields an internal error.
+    pub fn new_negative_pi_lower(datatype: &DataType) -> Result<ScalarValue> {
+        match datatype {
+            DataType::Float32 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F32)),
+            DataType::Float64 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F64)),
+            _ => {
+                // Keep the message on one line: a line break inside the
+                // literal would become part of the error text.
+                _internal_err!("-PI_LOWER is not supported for data type: {:?}", datatype)
+            }
+        }
+    }
+
+    /// Builds a [`ScalarValue`] holding the upper bound used for PI/2.
+    ///
+    /// The value is taken from the `consts` module. Only `Float32` and
+    /// `Float64` are accepted; any other type yields an internal error.
+    pub fn new_frac_pi_2_upper(datatype: &DataType) -> Result<ScalarValue> {
+        Ok(match datatype {
+            DataType::Float32 => ScalarValue::from(consts::FRAC_PI_2_UPPER_F32),
+            DataType::Float64 => ScalarValue::from(consts::FRAC_PI_2_UPPER_F64),
+            _ => {
+                return _internal_err!(
+                    "PI_UPPER/2 is not supported for data type: {:?}",
+                    datatype
+                )
+            }
+        })
+    }
+
+ /// Returns a [`ScalarValue`] representing -FRAC_PI_2's lower bound
+ pub fn new_neg_frac_pi_2_lower(datatype: &DataType) -> Result<ScalarValue>
{
+ match datatype {
+ DataType::Float32 => {
+ Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F32))
+ }
+ DataType::Float64 => {
+ Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F64))
+ }
+ _ => {
+ _internal_err!(
+ "-PI/2_LOWER is not supported for data type: {:?}",
+ datatype
+ )
+ }
+ }
+ }
+
+    /// Builds a [`ScalarValue`] holding -PI in the requested floating-point type.
+    ///
+    /// Only `Float32` and `Float64` are accepted; any other type yields an
+    /// internal error.
+    pub fn new_negative_pi(datatype: &DataType) -> Result<ScalarValue> {
+        let negative_pi = match datatype {
+            DataType::Float32 => ScalarValue::from(-std::f32::consts::PI),
+            DataType::Float64 => ScalarValue::from(-std::f64::consts::PI),
+            _ => return _internal_err!("-PI is not supported for data type: {:?}", datatype),
+        };
+        Ok(negative_pi)
+    }
+
+    /// Builds a [`ScalarValue`] holding PI/2 in the requested floating-point type.
+    ///
+    /// Only `Float32` and `Float64` are accepted; any other type yields an
+    /// internal error.
+    pub fn new_frac_pi_2(datatype: &DataType) -> Result<ScalarValue> {
+        Ok(match datatype {
+            DataType::Float32 => ScalarValue::from(std::f32::consts::FRAC_PI_2),
+            DataType::Float64 => ScalarValue::from(std::f64::consts::FRAC_PI_2),
+            _ => return _internal_err!("PI/2 is not supported for data type: {:?}", datatype),
+        })
+    }
+
+    /// Builds a [`ScalarValue`] holding -PI/2 in the requested floating-point type.
+    ///
+    /// Only `Float32` and `Float64` are accepted; any other type yields an
+    /// internal error.
+    pub fn new_neg_frac_pi_2(datatype: &DataType) -> Result<ScalarValue> {
+        let negative_half_pi = match datatype {
+            DataType::Float32 => ScalarValue::from(-std::f32::consts::FRAC_PI_2),
+            DataType::Float64 => ScalarValue::from(-std::f64::consts::FRAC_PI_2),
+            _ => return _internal_err!("-PI/2 is not supported for data type: {:?}", datatype),
+        };
+        Ok(negative_half_pi)
+    }
+
+    /// Returns a [`ScalarValue`] representing positive infinity.
+    ///
+    /// Only `Float32` and `Float64` are accepted; any other type yields an
+    /// internal error.
+    pub fn new_infinity(datatype: &DataType) -> Result<ScalarValue> {
+        match datatype {
+            DataType::Float32 => Ok(ScalarValue::from(f32::INFINITY)),
+            DataType::Float64 => Ok(ScalarValue::from(f64::INFINITY)),
+            _ => {
+                // Keep the message on one line: a line break inside the
+                // literal would become part of the error text.
+                _internal_err!("Infinity is not supported for data type: {:?}", datatype)
+            }
+        }
+    }
+
+    /// Builds a [`ScalarValue`] holding negative infinity in the requested
+    /// floating-point type.
+    ///
+    /// Only `Float32` and `Float64` are accepted; any other type yields an
+    /// internal error.
+    pub fn new_neg_infinity(datatype: &DataType) -> Result<ScalarValue> {
+        Ok(match datatype {
+            DataType::Float32 => ScalarValue::from(f32::NEG_INFINITY),
+            DataType::Float64 => ScalarValue::from(f64::NEG_INFINITY),
+            _ => {
+                return _internal_err!(
+                    "Negative Infinity is not supported for data type: {:?}",
+                    datatype
+                )
+            }
+        })
+    }
+
/// Create a zero value in the given type.
pub fn new_zero(datatype: &DataType) -> Result<ScalarValue> {
Ok(match datatype {
diff --git a/datafusion/expr/src/interval_arithmetic.rs
b/datafusion/expr/src/interval_arithmetic.rs
index 18f92334ff..d0dd418c78 100644
--- a/datafusion/expr/src/interval_arithmetic.rs
+++ b/datafusion/expr/src/interval_arithmetic.rs
@@ -332,6 +332,38 @@ impl Interval {
Ok(Self::new(unbounded_endpoint.clone(), unbounded_endpoint))
}
+    /// Builds the closed interval [-1, 1] for `data_type`.
+    ///
+    /// Fails if either endpoint cannot be created for the type or if
+    /// `try_new` rejects the pair.
+    pub fn make_symmetric_unit_interval(data_type: &DataType) -> Result<Self> {
+        let lower = ScalarValue::new_negative_one(data_type)?;
+        let upper = ScalarValue::new_one(data_type)?;
+        Self::try_new(lower, upper)
+    }
+
+    /// Builds an interval covering -π to π for `data_type`.
+    ///
+    /// The endpoints come from `ScalarValue::new_negative_pi_lower` and
+    /// `ScalarValue::new_pi_upper`. Fails if either endpoint cannot be
+    /// created for the type or if `try_new` rejects the pair.
+    pub fn make_symmetric_pi_interval(data_type: &DataType) -> Result<Self> {
+        let lower = ScalarValue::new_negative_pi_lower(data_type)?;
+        let upper = ScalarValue::new_pi_upper(data_type)?;
+        Self::try_new(lower, upper)
+    }
+
+    /// Builds an interval covering -π/2 to π/2 for `data_type`.
+    ///
+    /// The endpoints come from `ScalarValue::new_neg_frac_pi_2_lower` and
+    /// `ScalarValue::new_frac_pi_2_upper`. Fails if either endpoint cannot be
+    /// created for the type or if `try_new` rejects the pair.
+    pub fn make_symmetric_half_pi_interval(data_type: &DataType) -> Result<Self> {
+        let lower = ScalarValue::new_neg_frac_pi_2_lower(data_type)?;
+        let upper = ScalarValue::new_frac_pi_2_upper(data_type)?;
+        Self::try_new(lower, upper)
+    }
+
+    /// Builds an interval from 0 to infinity for `data_type`.
+    ///
+    /// The upper endpoint is whatever `ScalarValue::try_from(&DataType)`
+    /// yields (presumably the typed null that marks an unbounded side;
+    /// TODO confirm). Fails if either endpoint cannot be created or if
+    /// `try_new` rejects the pair.
+    pub fn make_non_negative_infinity_interval(data_type: &DataType) -> Result<Self> {
+        let lower = ScalarValue::new_zero(data_type)?;
+        let upper = ScalarValue::try_from(data_type)?;
+        Self::try_new(lower, upper)
+    }
+
/// Returns a reference to the lower bound.
pub fn lower(&self) -> &ScalarValue {
&self.lower
diff --git a/datafusion/functions/src/macros.rs
b/datafusion/functions/src/macros.rs
index cae689b3e0..e26c94e1bb 100644
--- a/datafusion/functions/src/macros.rs
+++ b/datafusion/functions/src/macros.rs
@@ -162,7 +162,7 @@ macro_rules! downcast_arg {
/// $UNARY_FUNC: the unary function to apply to the argument
/// $OUTPUT_ORDERING: the output ordering calculation method of the function
macro_rules! make_math_unary_udf {
- ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident,
$OUTPUT_ORDERING:expr) => {
+ ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident,
$OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr) => {
make_udf_function!($NAME::$UDF, $GNAME, $NAME);
mod $NAME {
@@ -172,6 +172,7 @@ macro_rules! make_math_unary_udf {
use arrow::array::{ArrayRef, Float32Array, Float64Array};
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, DataFusionError, Result};
+ use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::{ExprProperties,
SortProperties};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature,
Volatility};
@@ -222,6 +223,10 @@ macro_rules! make_math_unary_udf {
$OUTPUT_ORDERING(input)
}
+ fn evaluate_bounds(&self, inputs: &[&Interval]) ->
Result<Interval> {
+ $EVALUATE_BOUNDS(inputs)
+ }
+
fn invoke(&self, args: &[ColumnarValue]) ->
Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
diff --git a/datafusion/functions/src/math/bounds.rs
b/datafusion/functions/src/math/bounds.rs
new file mode 100644
index 0000000000..894d2bded5
--- /dev/null
+++ b/datafusion/functions/src/math/bounds.rs
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::ScalarValue;
+use datafusion_expr::interval_arithmetic::Interval;
+
+/// Bounds for functions with no tighter output range: an unbounded interval
+/// typed after the first input interval.
+pub(super) fn unbounded_bounds(input: &[&Interval]) -> crate::Result<Interval> {
+    let arg = input[0];
+    Interval::make_unbounded(&arg.data_type())
+}
+
+/// Bounds for `sin`: [-1, 1], typed after the first input interval.
+pub(super) fn sin_bounds(input: &[&Interval]) -> crate::Result<Interval> {
+    let arg = input[0];
+    Interval::make_symmetric_unit_interval(&arg.data_type())
+}
+
+/// Bounds for `asin`: [-π/2, π/2], typed after the first input interval.
+pub(super) fn asin_bounds(input: &[&Interval]) -> crate::Result<Interval> {
+    let arg = input[0];
+    Interval::make_symmetric_half_pi_interval(&arg.data_type())
+}
+
+/// Bounds for `atan`: [-π/2, π/2], typed after the first input interval.
+pub(super) fn atan_bounds(input: &[&Interval]) -> crate::Result<Interval> {
+    let arg = input[0];
+    Interval::make_symmetric_half_pi_interval(&arg.data_type())
+}
+
+/// Bounds for `acos`: [0, π], typed after the first input interval.
+///
+/// The upper endpoint is the PI upper-bound scalar from
+/// `ScalarValue::new_pi_upper`.
+pub(super) fn acos_bounds(input: &[&Interval]) -> crate::Result<Interval> {
+    let arg_type = input[0].data_type();
+    let lower = ScalarValue::new_zero(&arg_type)?;
+    let upper = ScalarValue::new_pi_upper(&arg_type)?;
+    Interval::try_new(lower, upper)
+}
+
+/// Bounds for `acosh`: [0, ∞), typed after the first input interval.
+pub(super) fn acosh_bounds(input: &[&Interval]) -> crate::Result<Interval> {
+    let arg = input[0];
+    Interval::make_non_negative_infinity_interval(&arg.data_type())
+}
+
+/// Bounds for `cos`: [-1, 1], typed after the first input interval.
+pub(super) fn cos_bounds(input: &[&Interval]) -> crate::Result<Interval> {
+    let arg = input[0];
+    Interval::make_symmetric_unit_interval(&arg.data_type())
+}
+
+/// Bounds for `cosh`: [1, ∞), typed after the first input interval.
+///
+/// The upper endpoint is whatever `ScalarValue::try_from(&DataType)` yields
+/// (presumably the typed null that marks an unbounded side; TODO confirm).
+pub(super) fn cosh_bounds(input: &[&Interval]) -> crate::Result<Interval> {
+    let arg_type = input[0].data_type();
+    let lower = ScalarValue::new_one(&arg_type)?;
+    let upper = ScalarValue::try_from(&arg_type)?;
+    Interval::try_new(lower, upper)
+}
+
+/// Bounds for `exp`: [0, ∞), typed after the first input interval.
+pub(super) fn exp_bounds(input: &[&Interval]) -> crate::Result<Interval> {
+    let arg = input[0];
+    Interval::make_non_negative_infinity_interval(&arg.data_type())
+}
+
+pub(super) fn radians_bounds(input: &[&Interval]) -> crate::Result<Interval> {
+ // Returns the symmetric -π..π interval regardless of the input range. NOTE(review): the `radians` UDF is wired to `to_radians` (a degrees-to-radians scale), whose output looks unbounded for unbounded input (e.g. 360° maps to 2π) — confirm this bound is intended before relying on it.
+ let data_type = input[0].data_type();
+
+ Interval::make_symmetric_pi_interval(&data_type)
+}
+
+/// Bounds for `sqrt`: [0, ∞), typed after the first input interval.
+pub(super) fn sqrt_bounds(input: &[&Interval]) -> crate::Result<Interval> {
+    let arg = input[0];
+    Interval::make_non_negative_infinity_interval(&arg.data_type())
+}
+
+/// Bounds for `tanh`: the open range (-1, 1), reported here as the closed
+/// interval [-1, 1] typed after the first input interval.
+pub(super) fn tanh_bounds(input: &[&Interval]) -> crate::Result<Interval> {
+    let arg = input[0];
+    Interval::make_symmetric_unit_interval(&arg.data_type())
+}
diff --git a/datafusion/functions/src/math/mod.rs
b/datafusion/functions/src/math/mod.rs
index 3b32a158b8..1e41fff289 100644
--- a/datafusion/functions/src/math/mod.rs
+++ b/datafusion/functions/src/math/mod.rs
@@ -22,6 +22,7 @@ use datafusion_expr::ScalarUDF;
use std::sync::Arc;
pub mod abs;
+pub mod bounds;
pub mod cot;
pub mod factorial;
pub mod gcd;
@@ -40,36 +41,142 @@ pub mod trunc;
// Create UDFs
make_udf_function!(abs::AbsFunc, ABS, abs);
-make_math_unary_udf!(AcosFunc, ACOS, acos, acos, super::acos_order);
-make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, super::acosh_order);
-make_math_unary_udf!(AsinFunc, ASIN, asin, asin, super::asin_order);
-make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, super::asinh_order);
-make_math_unary_udf!(AtanFunc, ATAN, atan, atan, super::atan_order);
-make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, super::atanh_order);
+make_math_unary_udf!(
+ AcosFunc,
+ ACOS,
+ acos,
+ acos,
+ super::acos_order,
+ super::bounds::acos_bounds
+);
+make_math_unary_udf!(
+ AcoshFunc,
+ ACOSH,
+ acosh,
+ acosh,
+ super::acosh_order,
+ super::bounds::acosh_bounds
+);
+make_math_unary_udf!(
+ AsinFunc,
+ ASIN,
+ asin,
+ asin,
+ super::asin_order,
+ super::bounds::asin_bounds
+);
+make_math_unary_udf!(
+ AsinhFunc,
+ ASINH,
+ asinh,
+ asinh,
+ super::asinh_order,
+ super::bounds::unbounded_bounds
+);
+make_math_unary_udf!(
+ AtanFunc,
+ ATAN,
+ atan,
+ atan,
+ super::atan_order,
+ super::bounds::atan_bounds
+);
+make_math_unary_udf!(
+ AtanhFunc,
+ ATANH,
+ atanh,
+ atanh,
+ super::atanh_order,
+ super::bounds::unbounded_bounds
+);
make_math_binary_udf!(Atan2, ATAN2, atan2, atan2, super::atan2_order);
-make_math_unary_udf!(CbrtFunc, CBRT, cbrt, cbrt, super::cbrt_order);
-make_math_unary_udf!(CeilFunc, CEIL, ceil, ceil, super::ceil_order);
-make_math_unary_udf!(CosFunc, COS, cos, cos, super::cos_order);
-make_math_unary_udf!(CoshFunc, COSH, cosh, cosh, super::cosh_order);
+make_math_unary_udf!(
+ CbrtFunc,
+ CBRT,
+ cbrt,
+ cbrt,
+ super::cbrt_order,
+ super::bounds::unbounded_bounds
+);
+make_math_unary_udf!(
+ CeilFunc,
+ CEIL,
+ ceil,
+ ceil,
+ super::ceil_order,
+ super::bounds::unbounded_bounds
+);
+make_math_unary_udf!(
+ CosFunc,
+ COS,
+ cos,
+ cos,
+ super::cos_order,
+ super::bounds::cos_bounds
+);
+make_math_unary_udf!(
+ CoshFunc,
+ COSH,
+ cosh,
+ cosh,
+ super::cosh_order,
+ super::bounds::cosh_bounds
+);
make_udf_function!(cot::CotFunc, COT, cot);
make_math_unary_udf!(
DegreesFunc,
DEGREES,
degrees,
to_degrees,
- super::degrees_order
+ super::degrees_order,
+ super::bounds::unbounded_bounds
+);
+make_math_unary_udf!(
+ ExpFunc,
+ EXP,
+ exp,
+ exp,
+ super::exp_order,
+ super::bounds::exp_bounds
);
-make_math_unary_udf!(ExpFunc, EXP, exp, exp, super::exp_order);
make_udf_function!(factorial::FactorialFunc, FACTORIAL, factorial);
-make_math_unary_udf!(FloorFunc, FLOOR, floor, floor, super::floor_order);
+make_math_unary_udf!(
+ FloorFunc,
+ FLOOR,
+ floor,
+ floor,
+ super::floor_order,
+ super::bounds::unbounded_bounds
+);
make_udf_function!(log::LogFunc, LOG, log);
make_udf_function!(gcd::GcdFunc, GCD, gcd);
make_udf_function!(nans::IsNanFunc, ISNAN, isnan);
make_udf_function!(iszero::IsZeroFunc, ISZERO, iszero);
make_udf_function!(lcm::LcmFunc, LCM, lcm);
-make_math_unary_udf!(LnFunc, LN, ln, ln, super::ln_order);
-make_math_unary_udf!(Log2Func, LOG2, log2, log2, super::log2_order);
-make_math_unary_udf!(Log10Func, LOG10, log10, log10, super::log10_order);
+make_math_unary_udf!(
+ LnFunc,
+ LN,
+ ln,
+ ln,
+ super::ln_order,
+ super::bounds::unbounded_bounds
+);
+make_math_unary_udf!(
+ Log2Func,
+ LOG2,
+ log2,
+ log2,
+ super::log2_order,
+ super::bounds::unbounded_bounds
+);
+make_math_unary_udf!(
+ Log10Func,
+ LOG10,
+ log10,
+ log10,
+ super::log10_order,
+ super::bounds::unbounded_bounds
+);
make_udf_function!(nanvl::NanvlFunc, NANVL, nanvl);
make_udf_function!(pi::PiFunc, PI, pi);
make_udf_function!(power::PowerFunc, POWER, power);
@@ -78,16 +185,52 @@ make_math_unary_udf!(
RADIANS,
radians,
to_radians,
- super::radians_order
+ super::radians_order,
+ super::bounds::radians_bounds
);
make_udf_function!(random::RandomFunc, RANDOM, random);
make_udf_function!(round::RoundFunc, ROUND, round);
make_udf_function!(signum::SignumFunc, SIGNUM, signum);
-make_math_unary_udf!(SinFunc, SIN, sin, sin, super::sin_order);
-make_math_unary_udf!(SinhFunc, SINH, sinh, sinh, super::sinh_order);
-make_math_unary_udf!(SqrtFunc, SQRT, sqrt, sqrt, super::sqrt_order);
-make_math_unary_udf!(TanFunc, TAN, tan, tan, super::tan_order);
-make_math_unary_udf!(TanhFunc, TANH, tanh, tanh, super::tanh_order);
+make_math_unary_udf!(
+ SinFunc,
+ SIN,
+ sin,
+ sin,
+ super::sin_order,
+ super::bounds::sin_bounds
+);
+make_math_unary_udf!(
+ SinhFunc,
+ SINH,
+ sinh,
+ sinh,
+ super::sinh_order,
+ super::bounds::unbounded_bounds
+);
+make_math_unary_udf!(
+ SqrtFunc,
+ SQRT,
+ sqrt,
+ sqrt,
+ super::sqrt_order,
+ super::bounds::sqrt_bounds
+);
+make_math_unary_udf!(
+ TanFunc,
+ TAN,
+ tan,
+ tan,
+ super::tan_order,
+ super::bounds::unbounded_bounds
+);
+make_math_unary_udf!(
+ TanhFunc,
+ TANH,
+ tanh,
+ tanh,
+ super::tanh_order,
+ super::bounds::tanh_bounds
+);
make_udf_function!(trunc::TruncFunc, TRUNC, trunc);
pub mod expr_fn {
@@ -175,3 +318,118 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
trunc(),
]
}
+
+#[cfg(test)]
+mod tests {
+    use arrow::datatypes::DataType;
+    use datafusion_common::ScalarValue;
+    use datafusion_expr::interval_arithmetic::Interval;
+
+    /// Unbounded interval of the given type, used as the input for every case.
+    fn unbounded_interval(data_type: &DataType) -> Interval {
+        Interval::make_unbounded(data_type).unwrap()
+    }
+
+    /// Expected bounds for `cosh`: lower endpoint one, upper endpoint from
+    /// `ScalarValue::try_from(&DataType)`.
+    fn one_to_inf_interval(data_type: &DataType) -> Interval {
+        Interval::try_new(
+            ScalarValue::new_one(data_type).unwrap(),
+            ScalarValue::try_from(data_type).unwrap(),
+        )
+        .unwrap()
+    }
+
+    /// Expected bounds for `acos`: zero up to the PI upper-bound scalar.
+    fn zero_to_pi_interval(data_type: &DataType) -> Interval {
+        Interval::try_new(
+            ScalarValue::new_zero(data_type).unwrap(),
+            ScalarValue::new_pi_upper(data_type).unwrap(),
+        )
+        .unwrap()
+    }
+
+    /// Evaluates `udf`'s bounds on `interval` and asserts they equal `expected`.
+    fn assert_udf_evaluates_to_bounds(
+        udf: &datafusion_expr::ScalarUDF,
+        interval: Interval,
+        expected: Interval,
+    ) {
+        let input = vec![&interval];
+        let result = udf.evaluate_bounds(&input).unwrap();
+        assert_eq!(
+            result,
+            expected,
+            "Bounds check failed on UDF: {:?}",
+            udf.name()
+        );
+    }
+
+    /// Checks, for both float types, that each bounded unary math UDF maps an
+    /// unbounded input interval to its expected output bounds.
+    #[test]
+    fn test_cases() -> crate::Result<()> {
+        let datatypes = [DataType::Float32, DataType::Float64];
+        let cases = datatypes
+            .iter()
+            .flat_map(|data_type| {
+                vec![
+                    (
+                        super::acos(),
+                        unbounded_interval(data_type),
+                        zero_to_pi_interval(data_type),
+                    ),
+                    (
+                        super::acosh(),
+                        unbounded_interval(data_type),
+                        Interval::make_non_negative_infinity_interval(data_type).unwrap(),
+                    ),
+                    (
+                        super::asin(),
+                        unbounded_interval(data_type),
+                        Interval::make_symmetric_half_pi_interval(data_type).unwrap(),
+                    ),
+                    (
+                        super::atan(),
+                        unbounded_interval(data_type),
+                        Interval::make_symmetric_half_pi_interval(data_type).unwrap(),
+                    ),
+                    (
+                        super::cos(),
+                        unbounded_interval(data_type),
+                        Interval::make_symmetric_unit_interval(data_type).unwrap(),
+                    ),
+                    (
+                        super::cosh(),
+                        unbounded_interval(data_type),
+                        one_to_inf_interval(data_type),
+                    ),
+                    (
+                        super::sin(),
+                        unbounded_interval(data_type),
+                        Interval::make_symmetric_unit_interval(data_type).unwrap(),
+                    ),
+                    (
+                        super::exp(),
+                        unbounded_interval(data_type),
+                        Interval::make_non_negative_infinity_interval(data_type).unwrap(),
+                    ),
+                    (
+                        super::sqrt(),
+                        unbounded_interval(data_type),
+                        Interval::make_non_negative_infinity_interval(data_type).unwrap(),
+                    ),
+                    (
+                        super::radians(),
+                        unbounded_interval(data_type),
+                        Interval::make_symmetric_pi_interval(data_type).unwrap(),
+                    ),
+                    // Previously a second, identical `sqrt` case; `tanh` was
+                    // the one bounded function without coverage.
+                    (
+                        super::tanh(),
+                        unbounded_interval(data_type),
+                        Interval::make_symmetric_unit_interval(data_type).unwrap(),
+                    ),
+                ]
+            })
+            .collect::<Vec<_>>();
+
+        for (udf, interval, expected) in cases {
+            assert_udf_evaluates_to_bounds(&udf, interval, expected);
+        }
+
+        Ok(())
+    }
+}
diff --git a/datafusion/functions/src/math/monotonicity.rs
b/datafusion/functions/src/math/monotonicity.rs
index 33c061ee11..52f2ec5171 100644
--- a/datafusion/functions/src/math/monotonicity.rs
+++ b/datafusion/functions/src/math/monotonicity.rs
@@ -15,24 +15,17 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::datatypes::DataType;
use datafusion_common::{exec_err, Result, ScalarValue};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
-fn symmetric_unit_interval(data_type: &DataType) -> Result<Interval> {
- Interval::try_new(
- ScalarValue::new_negative_one(data_type)?,
- ScalarValue::new_one(data_type)?,
- )
-}
-
/// Non-increasing on the interval \[−1, 1\], undefined otherwise.
pub fn acos_order(input: &[ExprProperties]) -> Result<SortProperties> {
let arg = &input[0];
let range = &arg.range;
- let valid_domain = symmetric_unit_interval(&range.lower().data_type())?;
+ let valid_domain =
+ Interval::make_symmetric_unit_interval(&range.lower().data_type())?;
if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE {
Ok(-arg.sort_properties)
@@ -63,7 +56,8 @@ pub fn asin_order(input: &[ExprProperties]) ->
Result<SortProperties> {
let arg = &input[0];
let range = &arg.range;
- let valid_domain = symmetric_unit_interval(&range.lower().data_type())?;
+ let valid_domain =
+ Interval::make_symmetric_unit_interval(&range.lower().data_type())?;
if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE {
Ok(arg.sort_properties)
@@ -87,7 +81,8 @@ pub fn atanh_order(input: &[ExprProperties]) ->
Result<SortProperties> {
let arg = &input[0];
let range = &arg.range;
- let valid_domain = symmetric_unit_interval(&range.lower().data_type())?;
+ let valid_domain =
+ Interval::make_symmetric_unit_interval(&range.lower().data_type())?;
if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE {
Ok(arg.sort_properties)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]