This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new f05df7beaf functions: support trunc() function with one or two args (#6942)
f05df7beaf is described below

commit f05df7beaf09486be2544ab6e397a480651a2778
Author: Syleechan <[email protected]>
AuthorDate: Fri Jul 21 01:10:07 2023 +0800

    functions: support trunc() function with one or two args (#6942)
    
    * functions: support trunc() function with one or two args
    
    * format code style
    
    * modify truncate method
    
    * adjust code format
    
    * format code
    
    * fix sql test error
---
 .../core/tests/sqllogictests/test_files/scalar.slt |   6 +
 datafusion/expr/src/built_in_function.rs           |  10 +-
 datafusion/expr/src/expr_fn.rs                     |   9 +-
 datafusion/physical-expr/src/functions.rs          |   4 +-
 datafusion/physical-expr/src/math_expressions.rs   | 144 ++++++++++++++++++++-
 datafusion/proto/src/logical_plan/from_proto.rs    |   7 +-
 6 files changed, 173 insertions(+), 7 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/scalar.slt 
b/datafusion/core/tests/sqllogictests/test_files/scalar.slt
index 8c5c399c39..6e563a671d 100644
--- a/datafusion/core/tests/sqllogictests/test_files/scalar.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/scalar.slt
@@ -912,6 +912,12 @@ select trunc(a), trunc(b), trunc(c) from small_floats;
 0 0 0
 0 0 1
 
+# trunc with precision
+query RRRRR rowsort
+select trunc(4.267, 3), trunc(1.1234, 2), trunc(-1.1231, 6), trunc(1.2837284, 
2), trunc(1.1, 0);
+----
+4.267 1.12 -1.1231 1.28 1
+
 ## bitwise and
 
 # bitwise and with column and scalar
diff --git a/datafusion/expr/src/built_in_function.rs 
b/datafusion/expr/src/built_in_function.rs
index 74561d9fd7..66c20d362e 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -1072,6 +1072,15 @@ impl BuiltinScalarFunction {
                 ],
                 self.volatility(),
             ),
+            BuiltinScalarFunction::Trunc => Signature::one_of(
+                vec![
+                    Exact(vec![Float32, Int64]),
+                    Exact(vec![Float64, Int64]),
+                    Exact(vec![Float64]),
+                    Exact(vec![Float32]),
+                ],
+                self.volatility(),
+            ),
             BuiltinScalarFunction::Atan2 => Signature::one_of(
                 vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, 
Float64])],
                 self.volatility(),
@@ -1116,7 +1125,6 @@ impl BuiltinScalarFunction {
             | BuiltinScalarFunction::Sqrt
             | BuiltinScalarFunction::Tan
             | BuiltinScalarFunction::Tanh
-            | BuiltinScalarFunction::Trunc
             | BuiltinScalarFunction::Cot => {
                 // math expressions expect 1 argument of type f64 or f32
                 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR 
(real numbers) and thus we
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index b175fc6f51..30d9580c42 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -502,7 +502,11 @@ scalar_expr!(
 scalar_expr!(Degrees, degrees, num, "converts radians to degrees");
 scalar_expr!(Radians, radians, num, "converts degrees to radians");
 nary_scalar_expr!(Round, round, "round to nearest integer");
-scalar_expr!(Trunc, trunc, num, "truncate toward zero");
+nary_scalar_expr!(
+    Trunc,
+    trunc, // called with [num] or [num, precision] (see the Trunc signature)
+    "truncate toward zero, with optional precision"
+);
 scalar_expr!(Abs, abs, num, "absolute value");
 scalar_expr!(Signum, signum, num, "sign of the argument (-1, 0, +1) ");
 scalar_expr!(Exp, exp, num, "exponential");
@@ -929,7 +933,8 @@ mod test {
         test_unary_scalar_expr!(Radians, radians);
         test_nary_scalar_expr!(Round, round, input);
         test_nary_scalar_expr!(Round, round, input, decimal_places);
-        test_unary_scalar_expr!(Trunc, trunc);
+        test_nary_scalar_expr!(Trunc, trunc, num);
+        test_nary_scalar_expr!(Trunc, trunc, num, precision);
         test_unary_scalar_expr!(Abs, abs);
         test_unary_scalar_expr!(Signum, signum);
         test_unary_scalar_expr!(Exp, exp);
diff --git a/datafusion/physical-expr/src/functions.rs 
b/datafusion/physical-expr/src/functions.rs
index a92d4335d4..14279d7006 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -392,7 +392,9 @@ pub fn create_physical_fun(
         BuiltinScalarFunction::Cbrt => Arc::new(math_expressions::cbrt),
         BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan),
         BuiltinScalarFunction::Tanh => Arc::new(math_expressions::tanh),
-        BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc),
+        BuiltinScalarFunction::Trunc => {
+            Arc::new(|args| 
make_scalar_function(math_expressions::trunc)(args))
+        }
         BuiltinScalarFunction::Pi => Arc::new(math_expressions::pi),
         BuiltinScalarFunction::Power => {
             Arc::new(|args| 
make_scalar_function(math_expressions::power)(args))
diff --git a/datafusion/physical-expr/src/math_expressions.rs 
b/datafusion/physical-expr/src/math_expressions.rs
index 9a4653c8a0..883c016c04 100644
--- a/datafusion/physical-expr/src/math_expressions.rs
+++ b/datafusion/physical-expr/src/math_expressions.rs
@@ -21,7 +21,7 @@ use arrow::array::ArrayRef;
 use arrow::array::{Float32Array, Float64Array, Int64Array};
 use arrow::datatypes::DataType;
 use datafusion_common::ScalarValue;
-use datafusion_common::ScalarValue::Float32;
+use datafusion_common::ScalarValue::{Float32, Int64};
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::ColumnarValue;
 use rand::{thread_rng, Rng};
@@ -158,7 +158,6 @@ math_unary_function!("acosh", acosh);
 math_unary_function!("atanh", atanh);
 math_unary_function!("floor", floor);
 math_unary_function!("ceil", ceil);
-math_unary_function!("trunc", trunc);
 math_unary_function!("abs", abs);
 math_unary_function!("signum", signum);
 math_unary_function!("exp", exp);
@@ -530,6 +529,75 @@ fn compute_cot64(x: f64) -> f64 {
     1.0 / a
 }
 
+/// `trunc(numeric)` and `trunc(numeric, precision)` SQL functions: truncate toward zero
+pub fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 1 && args.len() != 2 {
+        return Err(DataFusionError::Internal(format!(
+            "truncate function requires one or two arguments, got {}",
+            args.len()
+        )));
+    }
+
+    // One argument: precision defaults to 0, so plain `f{32,64}::trunc` is used.
+    // Two arguments: each value is truncated to its row's precision (toward zero).
+    let num = &args[0];
+    let precision = if args.len() == 1 {
+        ColumnarValue::Scalar(Int64(Some(0)))
+    } else {
+        ColumnarValue::Array(args[1].clone())
+    };
+
+    match num.data_type() {
+        DataType::Float64 => match precision {
+            ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new(
+                make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }),
+            ) as ArrayRef),
+            ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!(
+                num,
+                precision,
+                "x",
+                "y",
+                Float64Array,
+                Int64Array,
+                { compute_truncate64 }
+            )) as ArrayRef),
+            _ => Err(DataFusionError::Internal(
+                "trunc function requires a scalar or array for precision".to_string(),
+            )),
+        },
+        DataType::Float32 => match precision {
+            ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new(
+                make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }),
+            ) as ArrayRef),
+            ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!(
+                num,
+                precision,
+                "x",
+                "y",
+                Float32Array,
+                Int64Array,
+                { compute_truncate32 }
+            )) as ArrayRef),
+            _ => Err(DataFusionError::Internal(
+                "trunc function requires a scalar or array for precision".to_string(),
+            )),
+        },
+        other => Err(DataFusionError::Internal(format!(
+            "Unsupported data type {other:?} for function trunc"
+        ))),
+    }
+}
+
+fn compute_truncate32(x: f32, y: i64) -> f32 {
+    let factor = 10.0_f32.powi(y as i32);
+    (x * factor).trunc() / factor
+}
+
+fn compute_truncate64(x: f64, y: i64) -> f64 {
+    let factor = 10.0_f64.powi(y as i32);
+    (x * factor).trunc() / factor
+}
+
 #[cfg(test)]
 mod tests {
 
@@ -818,4 +886,76 @@ mod tests {
         assert!((floats.value(2) - expected.value(2)).abs() < eps);
         assert!((floats.value(3) - expected.value(3)).abs() < eps);
     }
+
+    #[test]
+    fn test_truncate_32() {
+        let args: Vec<ArrayRef> = vec![
+            Arc::new(Float32Array::from(vec![
+                15.0,
+                1_234.267_8,
+                1_233.123_4,
+                3.312_979_2,
+                -21.123_4,
+            ])),
+            Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])),
+        ];
+
+        let result = trunc(&args).expect("failed to initialize function truncate");
+        let floats =
+            as_float32_array(&result).expect("failed to initialize function truncate");
+
+        assert_eq!(floats.len(), 5);
+        assert_eq!(floats.value(0), 15.0);
+        assert_eq!(floats.value(1), 1_234.267);
+        assert_eq!(floats.value(2), 1_233.12);
+        assert_eq!(floats.value(3), 3.312_97);
+        assert_eq!(floats.value(4), -21.123_4);
+    }
+
+    #[test]
+    fn test_truncate_64() {
+        let args: Vec<ArrayRef> = vec![
+            Arc::new(Float64Array::from(vec![
+                5.0,
+                234.267_812_176,
+                123.123_456_789,
+                123.312_979_313_2,
+                -321.123_1,
+            ])),
+            Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])),
+        ];
+
+        let result = trunc(&args).expect("failed to initialize function truncate");
+        let floats =
+            as_float64_array(&result).expect("failed to initialize function truncate");
+
+        assert_eq!(floats.len(), 5);
+        assert_eq!(floats.value(0), 5.0);
+        assert_eq!(floats.value(1), 234.267);
+        assert_eq!(floats.value(2), 123.12);
+        assert_eq!(floats.value(3), 123.312_97);
+        assert_eq!(floats.value(4), -321.123_1);
+    }
+
+    #[test]
+    fn test_truncate_64_one_arg() {
+        let args: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
+            5.0,
+            234.267_812,
+            123.123_45,
+            123.312_979_313_2,
+            -321.123,
+        ]))];
+
+        let result = trunc(&args).expect("failed to initialize function truncate");
+        let floats =
+            as_float64_array(&result).expect("failed to initialize function truncate");
+
+        assert_eq!(floats.len(), 5);
+        assert_eq!(floats.value(0), 5.0);
+        assert_eq!(floats.value(1), 234.0);
+        assert_eq!(floats.value(2), 123.0);
+        assert_eq!(floats.value(3), 123.0);
+        assert_eq!(floats.value(4), -321.0);
+    }
 }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs 
b/datafusion/proto/src/logical_plan/from_proto.rs
index 202de7df08..a3718090ed 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -1293,7 +1293,12 @@ pub fn parse_expr(
                         .map(|expr| parse_expr(expr, registry))
                         .collect::<Result<Vec<_>, _>>()?,
                 )),
-                ScalarFunction::Trunc => Ok(trunc(parse_expr(&args[0], 
registry)?)),
+                ScalarFunction::Trunc => Ok(trunc(
+                    args.to_owned()
+                        .iter()
+                        .map(|expr| parse_expr(expr, registry))
+                        .collect::<Result<Vec<_>, _>>()?,
+                )),
                 ScalarFunction::Abs => Ok(abs(parse_expr(&args[0], 
registry)?)),
                 ScalarFunction::Signum => Ok(signum(parse_expr(&args[0], 
registry)?)),
                 ScalarFunction::OctetLength => {

Reply via email to