Jefffrey commented on code in PR #19926:
URL: https://github.com/apache/datafusion/pull/19926#discussion_r2724895854


##########
datafusion/functions/src/math/round.rs:
##########
@@ -117,15 +209,135 @@ impl ScalarUDFImpl for RoundFunc {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        Ok(match arg_types[0].clone() {
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> 
Result<FieldRef> {
+        let input_field = &args.arg_fields[0];
+        let input_type = input_field.data_type();
+
+        // Get decimal_places from scalar_arguments
+        // If dp is not a constant scalar, we must keep the original scale 
because
+        // we can't determine a single output scale for varying per-row dp 
values.
+        let (decimal_places, dp_is_scalar) = match 
args.scalar_arguments.get(1) {
+            None => (0, true),        // No dp argument means default to 0
+            Some(None) => (0, false), // dp is a column
+            Some(Some(ScalarValue::Int32(Some(v)))) => (*v, true),
+            Some(Some(ScalarValue::Int64(Some(v)))) => {
+                let decimal_places = *v;
+                let v = i32::try_from(decimal_places).map_err(|_| {
+                    datafusion_common::DataFusionError::Execution(format!(
+                        "round decimal_places {decimal_places} is out of 
supported i32 range"
+                    ))
+                })?;
+                (v, true)
+            }
+            Some(Some(scalar)) if scalar.is_null() => (0, true), // null dp => 
output is null
+            Some(Some(other)) => {
+                return exec_err!(
+                    "Unexpected datatype for decimal_places: {}",
+                    other.data_type()
+                );
+            }
+        };
+
+        // Calculate return type based on input type
+        // For decimals: reduce scale to decimal_places (reclaims precision 
for integer part)
+        // This matches Spark/DuckDB behavior where ROUND adjusts the scale
+        // BUT only if dp is a constant - otherwise keep original scale and add
+        // extra precision to accommodate potential carry-over.
+        let return_type = match input_type {
             Float32 => Float32,
-            dt @ Decimal128(_, _)
-            | dt @ Decimal256(_, _)
-            | dt @ Decimal32(_, _)
-            | dt @ Decimal64(_, _) => dt,
+            Decimal32(precision, scale) => {

Review Comment:
   This can be deduplicated like so:
   
   ```rust
   /// Computes the output decimal `DataType` for `round` from the input precision/scale.
   ///
   /// * Constant `decimal_places` (`dp_is_scalar == true`):
   ///   - The scale becomes `output_scale_for_decimal(scale, decimal_places)`.
   ///   - One extra digit of precision (capped at `T::MAX_PRECISION`) is reserved
   ///     only for integer inputs (`scale == 0`) rounded to a negative
   ///     `decimal_places` whose magnitude does not exceed `precision`.
   ///   - That is the case where a carry such as 999 -> 1000 can add a digit.
   /// * Per-row `decimal_places`:
   ///   - The original scale is kept.
   ///   - One extra digit of precision is always reserved.
   fn calculate_new_precision_scale<T: DecimalType>(
       precision: u8,
       scale: i8,
       dp_is_scalar: bool,
       decimal_places: i32,
   ) -> Result<DataType> {
       // One digit wider than the input, but never beyond what `T` can represent.
       let widened = precision.saturating_add(1).min(T::MAX_PRECISION);

       if !dp_is_scalar {
           // dp varies per row: no single narrower scale exists, so keep it and widen.
           return Ok(T::TYPE_CONSTRUCTOR(widened, scale));
       }

       let new_scale = output_scale_for_decimal(scale, decimal_places)?;
       // `unsigned_abs` sidesteps the `i32::MIN` negation overflow; that value is far
       // larger than any `u8` precision, so it simply yields `false`.
       let may_carry = scale == 0
           && decimal_places < 0
           && decimal_places.unsigned_abs() <= u32::from(precision);
       let new_precision = if may_carry { widened } else { precision };
       Ok(T::TYPE_CONSTRUCTOR(new_precision, new_scale))
   }
   ```
   
   And use it like so:
   
   ```rust
               Decimal32(precision, scale) => {
                   calculate_new_precision_scale::<Decimal32Type>(
                       *precision,
                       *scale,
                       dp_is_scalar,
                       decimal_places,
                   )?
               }
   ```



##########
datafusion/functions/src/math/round.rs:
##########
@@ -168,28 +386,96 @@ impl ScalarUDFImpl for RoundFunc {
                     let rounded = round_float(*v, dp)?;
                     Ok(ColumnarValue::Scalar(ScalarValue::from(rounded)))
                 }
-                ScalarValue::Decimal128(Some(v), precision, scale) => {
-                    let rounded = round_decimal(*v, *scale, dp)?;
+                ScalarValue::Decimal128(Some(v), _precision, scale) => {
+                    let (out_precision, out_scale) =
+                        if let Decimal128(p, s) = args.return_type() {
+                            (*p, *s)
+                        } else {
+                            return internal_err!(
+                                "Unexpected return type for decimal128 round: 
{}",
+                                args.return_type()
+                            );
+                        };
+                    let rounded = round_decimal(*v, *scale, out_scale, dp)?;
+                    let rounded = if out_precision == DECIMAL128_MAX_PRECISION
+                        && *scale == 0
+                        && dp < 0
+                    {
+                        validate_decimal128_precision(rounded, out_precision)
+                    } else {
+                        Ok(rounded)
+                    }?;
                     let scalar =
-                        ScalarValue::Decimal128(Some(rounded), *precision, 
*scale);
+                        ScalarValue::Decimal128(Some(rounded), out_precision, 
out_scale);
                     Ok(ColumnarValue::Scalar(scalar))
                 }
-                ScalarValue::Decimal256(Some(v), precision, scale) => {
-                    let rounded = round_decimal(*v, *scale, dp)?;
+                ScalarValue::Decimal256(Some(v), _precision, scale) => {
+                    let (out_precision, out_scale) =
+                        if let Decimal256(p, s) = args.return_type() {
+                            (*p, *s)
+                        } else {
+                            return internal_err!(
+                                "Unexpected return type for decimal256 round: 
{}",
+                                args.return_type()
+                            );
+                        };
+                    let rounded = round_decimal(*v, *scale, out_scale, dp)?;
+                    let rounded = if out_precision == DECIMAL256_MAX_PRECISION
+                        && *scale == 0
+                        && dp < 0
+                    {
+                        validate_decimal256_precision(rounded, out_precision)
+                    } else {
+                        Ok(rounded)
+                    }?;
                     let scalar =
-                        ScalarValue::Decimal256(Some(rounded), *precision, 
*scale);
+                        ScalarValue::Decimal256(Some(rounded), out_precision, 
out_scale);
                     Ok(ColumnarValue::Scalar(scalar))
                 }
-                ScalarValue::Decimal64(Some(v), precision, scale) => {
-                    let rounded = round_decimal(*v, *scale, dp)?;
+                ScalarValue::Decimal64(Some(v), _precision, scale) => {
+                    let (out_precision, out_scale) =
+                        if let Decimal64(p, s) = args.return_type() {
+                            (*p, *s)
+                        } else {
+                            return internal_err!(
+                                "Unexpected return type for decimal64 round: 
{}",
+                                args.return_type()
+                            );
+                        };
+                    let rounded = round_decimal(*v, *scale, out_scale, dp)?;
+                    let rounded = if out_precision == DECIMAL64_MAX_PRECISION
+                        && *scale == 0
+                        && dp < 0
+                    {
+                        validate_decimal64_precision(rounded, out_precision)
+                    } else {
+                        Ok(rounded)
+                    }?;
                     let scalar =
-                        ScalarValue::Decimal64(Some(rounded), *precision, 
*scale);
+                        ScalarValue::Decimal64(Some(rounded), out_precision, 
out_scale);
                     Ok(ColumnarValue::Scalar(scalar))
                 }
-                ScalarValue::Decimal32(Some(v), precision, scale) => {
-                    let rounded = round_decimal(*v, *scale, dp)?;
+                ScalarValue::Decimal32(Some(v), _precision, scale) => {
+                    let (out_precision, out_scale) =
+                        if let Decimal32(p, s) = args.return_type() {

Review Comment:
   It might be better to pull this destructuring of the return type into the match,
   which would enable further deduplication across the decimal arms.



##########
datafusion/functions/src/math/round.rs:
##########
@@ -24,20 +24,112 @@ use arrow::datatypes::DataType::{
     Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64,
 };
 use arrow::datatypes::{
-    ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
-    Decimal256Type, Float32Type, Float64Type, Int32Type,
+    ArrowNativeTypeOp, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION,
+    DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, 
Decimal32Type,
+    Decimal64Type, Decimal128Type, Decimal256Type, Float32Type, Float64Type, 
Int32Type,
+    MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION,
+    MAX_DECIMAL128_FOR_EACH_PRECISION, MAX_DECIMAL256_FOR_EACH_PRECISION,
+    MIN_DECIMAL32_FOR_EACH_PRECISION, MIN_DECIMAL64_FOR_EACH_PRECISION,
+    MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL256_FOR_EACH_PRECISION, i256,
 };
+use arrow::datatypes::{Field, FieldRef};
 use arrow::error::ArrowError;
 use datafusion_common::types::{
     NativeType, logical_float32, logical_float64, logical_int32,
 };
 use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
 use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
 use datafusion_expr::{
-    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, 
Signature,
-    TypeSignature, TypeSignatureClass, Volatility,
+    Coercion, ColumnarValue, Documentation, ReturnFieldArgs, 
ScalarFunctionArgs,
+    ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
 };
 use datafusion_macros::user_doc;
+use std::sync::Arc;
+
+fn output_scale_for_decimal(input_scale: i8, decimal_places: i32) -> 
Result<i8> {
+    let new_scale = i32::from(input_scale).min(decimal_places.max(0));
+    i8::try_from(new_scale).map_err(|_| {
+        datafusion_common::DataFusionError::Internal(format!(
+            "Computed decimal scale {new_scale} is out of range for i8"
+        ))
+    })
+}
+
+fn validate_decimal32_precision(value: i32, precision: u8) -> Result<i32, 
ArrowError> {

Review Comment:
   Use [`DecimalType::validate_decimal_precision`](https://docs.rs/arrow/latest/arrow/datatypes/trait.DecimalType.html#tymethod.validate_decimal_precision)
   instead of a hand-written validator.



##########
datafusion/functions/src/math/round.rs:
##########
@@ -117,15 +209,135 @@ impl ScalarUDFImpl for RoundFunc {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        Ok(match arg_types[0].clone() {
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> 
Result<FieldRef> {
+        let input_field = &args.arg_fields[0];
+        let input_type = input_field.data_type();
+
+        // Get decimal_places from scalar_arguments
+        // If dp is not a constant scalar, we must keep the original scale 
because
+        // we can't determine a single output scale for varying per-row dp 
values.
+        let (decimal_places, dp_is_scalar) = match 
args.scalar_arguments.get(1) {

Review Comment:
   Maybe these should be folded into a single `Option`, if `decimal_places` is only
   used when `dp_is_scalar` is true.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to