Jefffrey commented on code in PR #22813:
URL: https://github.com/apache/datafusion/pull/22813#discussion_r3377334716
##########
datafusion/spark/src/function/math/round.rs:
##########
@@ -187,20 +190,43 @@ fn get_scale(args: &[ColumnarValue]) ->
Result<Option<i32>> {
/// round_float(125.0, -1) → 130.0
/// ```
fn round_float<T: num_traits::Float>(value: T, scale: i32) -> T {
- if scale >= 0 {
- let factor = T::from(10.0f64.powi(scale)).unwrap_or_else(T::infinity);
- if factor.is_infinite() {
- // Very large positive scale — value is already precise enough,
return as-is
- return value;
- }
- (value * factor).round() / factor
- } else {
- let factor = T::from(10.0f64.powi(-scale)).unwrap_or_else(T::infinity);
- if factor.is_infinite() {
- // Very large negative scale — any finite value rounds to 0
- return T::zero();
- }
- (value / factor).round() * factor
+ // Widen to f64 first. For f32 inputs this matches Spark's `f.toDouble`
+ // step (FloatType: `BigDecimal(f.toDouble).setScale(..).toFloat`), which
+ // exposes the binary-float error before rounding. For f64 it is a no-op.
+ let Some(d) = value.to_f64() else {
+ return value;
+ };
+
+ // Spark returns NaN / ±Inf unchanged; BigDecimal cannot represent them.
+ if !d.is_finite() {
+ return value;
+ }
+
+ // `d.to_string()` produces the shortest round-trip decimal string,
matching
+ // Scala's `BigDecimal(d) = java.math.BigDecimal.valueOf(d)` semantics. So
+ // `round(1.255_f64, 2)` parses "1.255" and rounds to 1.26 (not the naive
+ // binary-float 1.25).
+ let Ok(bd) = BigDecimal::from_str(&d.to_string()) else {
+ // Should not happen for a finite f64, but fall back gracefully.
+ return value;
+ };
+
+ // A finite f64 carries at most ~324 fractional decimal digits and
saturates
+ // below ~1e309 in magnitude, so any `scale` past those bounds is already a
+ // no-op (large positive) or collapses the value to zero (large negative).
+ // Clamp before `with_scale_round` so adversarial input such as
+ // `round(x, i32::MAX)` cannot drive an unbounded `10^scale` BigInt
+ // allocation. The clamp is exact for every finite f64.
+ let clamped_scale = i64::from(scale).clamp(-340, 340);
+
Review Comment:
If we are following Spark semantics, we might need to return an error here instead of clamping the scale. Spark fails on an out-of-range scale:
```python
>>> spark.version
'4.1.2'
>>> spark.sql("select round(1.255::double, 2147483647)").show()
Traceback (most recent call last):
File "<python-input-4>", line 1, in <module>
spark.sql("select round(1.255::double, 2147483647)").show()
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
File
"/Users/jeffrey/.cache/uv/archive-v0/GIQgMkXRrHZBaiUVcMOta/lib/python3.13/site-packages/pyspark/sql/classic/dataframe.py",
line 285, in show
print(self._show_string(n, truncate, vertical))
~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
File
"/Users/jeffrey/.cache/uv/archive-v0/GIQgMkXRrHZBaiUVcMOta/lib/python3.13/site-packages/pyspark/sql/classic/dataframe.py",
line 303, in _show_string
return self._jdf.showString(n, 20, vertical)
~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File
"/Users/jeffrey/.cache/uv/archive-v0/GIQgMkXRrHZBaiUVcMOta/lib/python3.13/site-packages/py4j/java_gateway.py",
line 1362, in __call__
return_value = get_return_value(
answer, self.gateway_client, self.target_id, self.name)
File
"/Users/jeffrey/.cache/uv/archive-v0/GIQgMkXRrHZBaiUVcMOta/lib/python3.13/site-packages/pyspark/errors/exceptions/captured.py",
line 269, in deco
raise converted from None
pyspark.errors.exceptions.captured.ArithmeticException: BigInteger would
overflow supported range
```
##########
datafusion/spark/src/function/math/round.rs:
##########
@@ -187,20 +190,43 @@ fn get_scale(args: &[ColumnarValue]) ->
Result<Option<i32>> {
/// round_float(125.0, -1) → 130.0
/// ```
fn round_float<T: num_traits::Float>(value: T, scale: i32) -> T {
- if scale >= 0 {
- let factor = T::from(10.0f64.powi(scale)).unwrap_or_else(T::infinity);
- if factor.is_infinite() {
- // Very large positive scale — value is already precise enough,
return as-is
- return value;
- }
- (value * factor).round() / factor
- } else {
- let factor = T::from(10.0f64.powi(-scale)).unwrap_or_else(T::infinity);
- if factor.is_infinite() {
- // Very large negative scale — any finite value rounds to 0
- return T::zero();
- }
- (value / factor).round() * factor
+ // Widen to f64 first. For f32 inputs this matches Spark's `f.toDouble`
+ // step (FloatType: `BigDecimal(f.toDouble).setScale(..).toFloat`), which
+ // exposes the binary-float error before rounding. For f64 it is a no-op.
+ let Some(d) = value.to_f64() else {
+ return value;
+ };
+
+ // Spark returns NaN / ±Inf unchanged; BigDecimal cannot represent them.
+ if !d.is_finite() {
+ return value;
+ }
+
+ // `d.to_string()` produces the shortest round-trip decimal string,
matching
+ // Scala's `BigDecimal(d) = java.math.BigDecimal.valueOf(d)` semantics. So
+ // `round(1.255_f64, 2)` parses "1.255" and rounds to 1.26 (not the naive
+ // binary-float 1.25).
+ let Ok(bd) = BigDecimal::from_str(&d.to_string()) else {
+ // Should not happen for a finite f64, but fall back gracefully.
+ return value;
+ };
Review Comment:
Something I find interesting: the Spark code for this apparently differs a bit between its two evaluation paths. For `nullSafeEval`:
```scala
case DoubleType =>
val d = input1.asInstanceOf[Double]
if (d.isNaN || d.isInfinite) {
d
} else {
BigDecimal(d).setScale(_scale, mode).toDouble
}
```
- https://github.com/apache/spark/blob/0993d4345969dfe16b334598dc80a452e4a270f7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L1609-L1615
- uses `BigDecimal(double val)`
Meanwhile, for `doGenCode`:
```scala
case DoubleType => // if child eval to NaN or Infinity, just return it.
s"""
if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})) {
${ev.value} = ${ce.value};
} else {
${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
setScale(${_scale},
java.math.BigDecimal.${modeStr}).doubleValue();
}"""
```
- https://github.com/apache/spark/blob/0993d4345969dfe16b334598dc80a452e4a270f7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L1670-L1678
- uses `BigDecimal.valueOf(double val)`

Do we need to account for this difference?
##########
datafusion/spark/src/function/math/round.rs:
##########
@@ -187,20 +190,43 @@ fn get_scale(args: &[ColumnarValue]) ->
Result<Option<i32>> {
/// round_float(125.0, -1) → 130.0
/// ```
fn round_float<T: num_traits::Float>(value: T, scale: i32) -> T {
- if scale >= 0 {
- let factor = T::from(10.0f64.powi(scale)).unwrap_or_else(T::infinity);
- if factor.is_infinite() {
- // Very large positive scale — value is already precise enough,
return as-is
- return value;
- }
- (value * factor).round() / factor
- } else {
- let factor = T::from(10.0f64.powi(-scale)).unwrap_or_else(T::infinity);
- if factor.is_infinite() {
- // Very large negative scale — any finite value rounds to 0
- return T::zero();
- }
- (value / factor).round() * factor
+ // Widen to f64 first. For f32 inputs this matches Spark's `f.toDouble`
+ // step (FloatType: `BigDecimal(f.toDouble).setScale(..).toFloat`), which
+ // exposes the binary-float error before rounding. For f64 it is a no-op.
+ let Some(d) = value.to_f64() else {
Review Comment:
We could also write it like this:
```rust
fn round_float<T: num_traits::Float + Into<f64>>(value: T, scale: i32) -> T {
// Spark returns NaN / ±Inf unchanged; BigDecimal cannot represent them.
if !value.is_finite() {
return value;
}
// Spark always widens f32: `BigDecimal(f.toDouble).setScale(..).toFloat`
// This exposes the binary-float error before rounding.
let d: f64 = value.into();
```
- finiteness can be checked on `T` directly, without converting to f64 first
- f32 widens to f64 losslessly (`Into<f64>`), so the `Option` unwrapping isn't needed
- the comment is more streamlined
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]