martin-g commented on code in PR #19926:
URL: https://github.com/apache/datafusion/pull/19926#discussion_r2720263975
##########
datafusion/functions/src/math/round.rs:
##########
@@ -117,15 +119,74 @@ impl ScalarUDFImpl for RoundFunc {
&self.signature
}
- fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- Ok(match arg_types[0].clone() {
+ fn return_field_from_args(&self, args: ReturnFieldArgs) ->
Result<FieldRef> {
+ let input_field = &args.arg_fields[0];
+ let input_type = input_field.data_type();
+
+ // Get decimal_places from scalar_arguments
+ // If dp is not a constant scalar, we must keep the original scale
because
+ // we can't determine a single output scale for varying per-row dp
values.
+ let (decimal_places, dp_is_scalar): (i32, bool) =
+ if args.scalar_arguments.len() > 1 {
+ match args.scalar_arguments[1] {
+ Some(ScalarValue::Int32(Some(v))) => (*v, true),
+ Some(ScalarValue::Int64(Some(v))) => (*v as i32, true),
+ _ => (0, false), // dp is a column or null - can't
determine scale
+ }
+ } else {
+ (0, true) // No dp argument means default to 0
+ };
+
+ // Calculate return type based on input type
+ // For decimals: reduce scale to decimal_places (reclaims precision
for integer part)
+ // This matches Spark/DuckDB behavior where ROUND adjusts the scale
+ // BUT only if dp is a constant - otherwise keep original scale
+ let return_type = match input_type {
Float32 => Float32,
- dt @ Decimal128(_, _)
- | dt @ Decimal256(_, _)
- | dt @ Decimal32(_, _)
- | dt @ Decimal64(_, _) => dt,
+ Decimal32(precision, scale) => {
+ if dp_is_scalar {
+ let new_scale = (*scale).min(decimal_places.max(0) as i8);
+ Decimal32(*precision, new_scale)
+ } else {
+ Decimal32(*precision, *scale)
+ }
+ }
+ Decimal64(precision, scale) => {
+ if dp_is_scalar {
+ let new_scale = (*scale).min(decimal_places.max(0) as i8);
+ Decimal64(*precision, new_scale)
+ } else {
+ Decimal64(*precision, *scale)
+ }
+ }
+ Decimal128(precision, scale) => {
+ if dp_is_scalar {
+ let new_scale = (*scale).min(decimal_places.max(0) as i8);
+ Decimal128(*precision, new_scale)
+ } else {
+ Decimal128(*precision, *scale)
+ }
+ }
+ Decimal256(precision, scale) => {
+ if dp_is_scalar {
+ let new_scale = (*scale).min(decimal_places.max(0) as i8);
+ Decimal256(*precision, new_scale)
+ } else {
+ Decimal256(*precision, *scale)
+ }
+ }
_ => Float64,
- })
+ };
+
+ Ok(Arc::new(Field::new(
+ self.name(),
+ return_type,
+ input_field.is_nullable(),
Review Comment:
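The nullability here should also take the `decimal_places` argument into
account: a NULL `decimal_places` yields a NULL result even when the input
itself is not nullable.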
Postgres:
```
postgres=# SELECT pg_typeof(round(999.9::DECIMAL(4,1))),
round(999.9::DECIMAL(4,1), NULL);
pg_typeof | round
-----------+-------
numeric |
(1 row)
```
Apache Spark:
```
spark-sql (default)> SELECT typeof(round(999.9::DECIMAL(4,1))),
round(999.9::DECIMAL(4,1), NULL);
decimal(4,0) NULL
Time taken: 0.055 seconds, Fetched 1 row(s)
```
DuckDB:
```
D SELECT typeof(round(999.9::DECIMAL(4,1))), round(999.9::DECIMAL(4,1), NULL);
┌────────────────────────────────────────────┬──────────────────────────────────────────┐
│ typeof(round(CAST(999.9 AS DECIMAL(4,1)))) │ round(CAST(999.9 AS DECIMAL(4,1)), NULL) │
│                  varchar                   │                  int32                   │
├────────────────────────────────────────────┼──────────────────────────────────────────┤
│                DECIMAL(4,0)                │                   NULL                   │
└────────────────────────────────────────────┴──────────────────────────────────────────┘
```
##########
datafusion/functions/src/math/round.rs:
##########
@@ -397,12 +501,14 @@ mod test {
decimal_places: Option<ArrayRef>,
) -> Result<ArrayRef, DataFusionError> {
let number_rows = value.len();
+ let return_type = value.data_type().clone();
Review Comment:
This could be wrong for decimals whose scale is reduced.
##########
datafusion/functions/src/math/round.rs:
##########
@@ -117,15 +119,74 @@ impl ScalarUDFImpl for RoundFunc {
&self.signature
}
- fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- Ok(match arg_types[0].clone() {
+ fn return_field_from_args(&self, args: ReturnFieldArgs) ->
Result<FieldRef> {
+ let input_field = &args.arg_fields[0];
+ let input_type = input_field.data_type();
+
+ // Get decimal_places from scalar_arguments
+ // If dp is not a constant scalar, we must keep the original scale
because
+ // we can't determine a single output scale for varying per-row dp
values.
+ let (decimal_places, dp_is_scalar): (i32, bool) =
+ if args.scalar_arguments.len() > 1 {
+ match args.scalar_arguments[1] {
+ Some(ScalarValue::Int32(Some(v))) => (*v, true),
+ Some(ScalarValue::Int64(Some(v))) => (*v as i32, true),
Review Comment:
Better to use a safe cast and return an `Err` if the number is bigger than
`i32::MAX` (or even `i8::MAX`).
##########
datafusion/functions/src/math/round.rs:
##########
@@ -117,15 +119,74 @@ impl ScalarUDFImpl for RoundFunc {
&self.signature
}
- fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- Ok(match arg_types[0].clone() {
+ fn return_field_from_args(&self, args: ReturnFieldArgs) ->
Result<FieldRef> {
+ let input_field = &args.arg_fields[0];
+ let input_type = input_field.data_type();
+
+ // Get decimal_places from scalar_arguments
+ // If dp is not a constant scalar, we must keep the original scale
because
+ // we can't determine a single output scale for varying per-row dp
values.
+ let (decimal_places, dp_is_scalar): (i32, bool) =
Review Comment:
Here `decimal_places` is initialized as `i32`, but below it is cast to `i8`
with `as i8`, which silently truncates out-of-range values. Validation is
needed to ensure it is not bigger than `i8::MAX`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]