This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 5bb6b35627 fix: use total ordering in the min & max accumulator for
floats (#10627)
5bb6b35627 is described below
commit 5bb6b356277ea1c6f1d7af64e2d66f005d7e1ed4
Author: Weston Pace <[email protected]>
AuthorDate: Fri Jun 7 07:26:53 2024 -0700
fix: use total ordering in the min & max accumulator for floats (#10627)
* fix: use total ordering in the min & max accumulator for floats to match
the ordering used by arrow kernels
* change unit test to expect min to be nan
* changed behavior again since the partial_cmp approach doesn't handle
nulls correctly
* Revert change to describe test. It was not originating from a nan/finite
discrepancy but from a null/defined discrepancy and we don't want that behavior
to change
* Update the test to check the min function and also verify the result
---
datafusion/physical-expr/src/aggregate/min_max.rs | 60 +++++++++++++++++++++--
1 file changed, 56 insertions(+), 4 deletions(-)
diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs
index 50bd24c487..a6d5054ec1 100644
--- a/datafusion/physical-expr/src/aggregate/min_max.rs
+++ b/datafusion/physical-expr/src/aggregate/min_max.rs
@@ -488,6 +488,20 @@ macro_rules! typed_min_max {
}};
}
+macro_rules! typed_min_max_float {
+ ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
+ ScalarValue::$SCALAR(match ($VALUE, $DELTA) {
+ (None, None) => None,
+ (Some(a), None) => Some(*a),
+ (None, Some(b)) => Some(*b),
+ (Some(a), Some(b)) => match a.total_cmp(b) {
+ choose_min_max!($OP) => Some(*b),
+ _ => Some(*a),
+ },
+ })
+ }};
+}
+
// min/max of two scalar string values.
macro_rules! typed_min_max_string {
($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
@@ -500,7 +514,7 @@ macro_rules! typed_min_max_string {
}};
}
-macro_rules! interval_choose_min_max {
+macro_rules! choose_min_max {
(min) => {
std::cmp::Ordering::Greater
};
@@ -512,7 +526,7 @@ macro_rules! interval_choose_min_max {
macro_rules! interval_min_max {
($OP:tt, $LHS:expr, $RHS:expr) => {{
match $LHS.partial_cmp(&$RHS) {
- Some(interval_choose_min_max!($OP)) => $RHS.clone(),
+ Some(choose_min_max!($OP)) => $RHS.clone(),
Some(_) => $LHS.clone(),
None => {
return internal_err!("Comparison error while computing interval min/max")
@@ -555,10 +569,10 @@ macro_rules! min_max {
typed_min_max!(lhs, rhs, Boolean, $OP)
}
(ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
- typed_min_max!(lhs, rhs, Float64, $OP)
+ typed_min_max_float!(lhs, rhs, Float64, $OP)
}
(ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
- typed_min_max!(lhs, rhs, Float32, $OP)
+ typed_min_max_float!(lhs, rhs, Float32, $OP)
}
(ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
typed_min_max!(lhs, rhs, UInt64, $OP)
@@ -1103,3 +1117,41 @@ impl Accumulator for SlidingMinAccumulator {
std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) +
self.min.size()
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn float_min_max_with_nans() {
+ let pos_nan = f32::NAN;
+ let zero = 0_f32;
+ let neg_inf = f32::NEG_INFINITY;
+
+ let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| {
+ for batch in values.iter() {
+ let batch =
+ Arc::new(Float32Array::from_iter_values(batch.iter().copied()));
+ acc.update_batch(&[batch]).unwrap();
+ }
+ let result = acc.evaluate().unwrap();
+ assert_eq!(result, ScalarValue::Float32(Some(expected)));
+ };
+
+ // This test checks both comparison between batches (which uses the min_max macro
+ // defined above) and within a batch (which uses the arrow min/max compute function
+ // and verifies both respect the total order comparison for floats)
+
+ let min = || MinAccumulator::try_new(&DataType::Float32).unwrap();
+ let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap();
+
+ check(&mut min(), &[&[zero], &[pos_nan]], zero);
+ check(&mut min(), &[&[zero, pos_nan]], zero);
+ check(&mut min(), &[&[zero], &[neg_inf]], neg_inf);
+ check(&mut min(), &[&[zero, neg_inf]], neg_inf);
+ check(&mut max(), &[&[zero], &[pos_nan]], pos_nan);
+ check(&mut max(), &[&[zero, pos_nan]], pos_nan);
+ check(&mut max(), &[&[zero], &[neg_inf]], zero);
+ check(&mut max(), &[&[zero, neg_inf]], zero);
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]