martin-g commented on code in PR #21062: URL: https://github.com/apache/datafusion/pull/21062#discussion_r2964423359
########## datafusion/spark/src/function/math/round.rs: ########## @@ -0,0 +1,659 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::*; +use arrow::datatypes::{ + ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, + Decimal256Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, + Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; +use datafusion_common::types::{ + NativeType, logical_float32, logical_float64, logical_int32, +}; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; + +/// Spark-compatible `round` expression +/// <https://spark.apache.org/docs/latest/api/sql/index.html#round> +/// +/// Rounds the value of `expr` to `scale` decimal places using HALF_UP rounding mode. +/// Returns the same type as the input expression. 
+/// +/// - `round(expr)` rounds to 0 decimal places (default scale = 0) +/// - `round(expr, scale)` rounds to `scale` decimal places +/// - For integer types with negative scale: `round(25, -1)` → `20` Review Comment: ```suggestion /// - For integer types with negative scale: `round(25, -1)` → `30` ``` ########## datafusion/spark/src/function/math/round.rs: ########## @@ -0,0 +1,659 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::*; +use arrow::datatypes::{ + ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, + Decimal256Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, + Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; +use datafusion_common::types::{ + NativeType, logical_float32, logical_float64, logical_int32, +}; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; + +/// Spark-compatible `round` expression +/// <https://spark.apache.org/docs/latest/api/sql/index.html#round> +/// +/// Rounds the value of `expr` to `scale` decimal places using HALF_UP rounding mode. +/// Returns the same type as the input expression. +/// +/// - `round(expr)` rounds to 0 decimal places (default scale = 0) +/// - `round(expr, scale)` rounds to `scale` decimal places +/// - For integer types with negative scale: `round(25, -1)` → `20` +/// - Uses HALF_UP rounding: 2.5 → 3, -2.5 → -3 (away from zero) +/// +/// Supported types: Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, +/// Float16, Float32, Float64, Decimal32, Decimal64, Decimal128, Decimal256 +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRound { + signature: Signature, +} + +impl Default for SparkRound { + fn default() -> Self { + Self::new() + } +} + +impl SparkRound { + pub fn new() -> Self { + let decimal = Coercion::new_exact(TypeSignatureClass::Decimal); + let integer = Coercion::new_exact(TypeSignatureClass::Integer); + let decimal_places = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ); + let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32())); + let float64 = Coercion::new_implicit( + 
TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); + Self { + signature: Signature::one_of( + vec![ + // round(decimal, scale) + TypeSignature::Coercible(vec![ + decimal.clone(), + decimal_places.clone(), + ]), + // round(decimal) + TypeSignature::Coercible(vec![decimal]), + // round(integer, scale) + TypeSignature::Coercible(vec![ + integer.clone(), + decimal_places.clone(), + ]), + // round(integer) + TypeSignature::Coercible(vec![integer]), + // round(float32, scale) + TypeSignature::Coercible(vec![ + float32.clone(), + decimal_places.clone(), + ]), + // round(float32) + TypeSignature::Coercible(vec![float32]), + // round(float64, scale) + TypeSignature::Coercible(vec![float64.clone(), decimal_places]), + // round(float64) + TypeSignature::Coercible(vec![float64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkRound { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "round" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + spark_round(&args.args, args.config_options.execution.enable_ansi_mode) + } +} + +/// Extract the scale (decimal places) from the second argument. +/// Returns `Some(0)` if no second argument is provided. +/// Returns `None` if the scale argument is NULL (Spark returns NULL for `round(expr, NULL)`). 
+fn get_scale(args: &[ColumnarValue]) -> Result<Option<i32>> { + if args.len() < 2 { + return Ok(Some(0)); + } + + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Int8(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::Int16(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => Ok(Some(*v)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => { + i32::try_from(*v).map(Some).map_err(|_| { + (exec_err!("round scale {v} is out of supported i32 range") + as Result<(), _>) + .unwrap_err() + }) + } + ColumnarValue::Scalar(ScalarValue::UInt8(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::UInt16(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::UInt32(Some(v))) => { + i32::try_from(*v).map(Some).map_err(|_| { + (exec_err!("round scale {v} is out of supported i32 range") + as Result<(), _>) + .unwrap_err() + }) + } + ColumnarValue::Scalar(ScalarValue::UInt64(Some(v))) => { + i32::try_from(*v).map(Some).map_err(|_| { + (exec_err!("round scale {v} is out of supported i32 range") + as Result<(), _>) + .unwrap_err() + }) + } + ColumnarValue::Scalar(sv) if sv.is_null() => Ok(None), + other => exec_err!("Unsupported type for round scale: {}", other.data_type()), + } +} + +/// Round a floating-point value to the given number of decimal places using +/// HALF_UP rounding mode (ties round away from zero). +/// +/// This matches Spark's `RoundBase` behaviour for `FloatType` / `DoubleType`, +/// which internally converts the value to `BigDecimal` and rounds with +/// `RoundingMode.HALF_UP`. +/// +/// # Arguments +/// * `value` – the floating-point number to round +/// * `scale` – number of decimal places to keep. +/// - `scale >= 0`: rounds to that many fractional digits +/// (e.g. `round_float(2.345, 2) == 2.35`) +/// - `scale < 0`: rounds to the left of the decimal point +/// (e.g. 
`round_float(125.0, -1) == 130.0`) +/// +/// # Examples +/// ```text +/// round_float(2.5, 0) → 3.0 // half rounds up +/// round_float(-2.5, 0) → -3.0 // half rounds away from zero +/// round_float(1.4, 0) → 1.0 +/// round_float(125.0, -1) → 130.0 +/// ``` +fn round_float<T: bigdecimal::num_traits::Float>(value: T, scale: i32) -> T { Review Comment: nit: `bigdecimal::num_traits` - this is a bit hacky! The DataFusion workspace has an entry for `num-traits` - https://github.com/apache/datafusion/blob/9885f4bfe88ae9e4df96466d7246b80244b8054a/Cargo.toml#L165 But the datafusion-spark crate does **not** use it - https://github.com/apache/datafusion/blob/main/datafusion/spark/Cargo.toml Here you are using the `num-traits` crate re-exported from `bigdecimal`. I think it would be better to add a dependency on `num-traits` to `datafusion-spark`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
