This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 5bb6b35627 fix: use total ordering in the min & max accumulator for floats (#10627)
5bb6b35627 is described below

commit 5bb6b356277ea1c6f1d7af64e2d66f005d7e1ed4
Author: Weston Pace <[email protected]>
AuthorDate: Fri Jun 7 07:26:53 2024 -0700

    fix: use total ordering in the min & max accumulator for floats (#10627)
    
    * fix: use total ordering in the min & max accumulator for floats to match the ordering used by arrow kernels
    
    * change unit test to expect min to be nan
    
    * changed behavior again since the partial_cmp approach doesn't handle nulls correctly
    
    * Revert change to describe test.  It was not originating from a NaN/finite discrepancy but from a null/defined discrepancy and we don't want that behavior to change
    
    * Update the test to check the min function and also verify the result
---
 datafusion/physical-expr/src/aggregate/min_max.rs | 60 +++++++++++++++++++++--
 1 file changed, 56 insertions(+), 4 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs
index 50bd24c487..a6d5054ec1 100644
--- a/datafusion/physical-expr/src/aggregate/min_max.rs
+++ b/datafusion/physical-expr/src/aggregate/min_max.rs
@@ -488,6 +488,20 @@ macro_rules! typed_min_max {
     }};
 }
 
+macro_rules! typed_min_max_float { // min/max of two optional float scalars, ordered by `total_cmp`
+    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
+        ScalarValue::$SCALAR(match ($VALUE, $DELTA) {
+            (None, None) => None,
+            (Some(a), None) => Some(*a), // a null operand never wins over a defined value
+            (None, Some(b)) => Some(*b),
+            (Some(a), Some(b)) => match a.total_cmp(b) { // IEEE 754 totalOrder: NaN is never incomparable
+                choose_min_max!($OP) => Some(*b), // `a` loses under $OP (for min: a > b), so keep `b`
+                _ => Some(*a),
+            },
+        })
+    }};
+}
+
 // min/max of two scalar string values.
 macro_rules! typed_min_max_string {
     ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
@@ -500,7 +514,7 @@ macro_rules! typed_min_max_string {
     }};
 }
 
-macro_rules! interval_choose_min_max {
+macro_rules! choose_min_max { // lhs-vs-rhs Ordering that means "take rhs" for the given min/max op
     (min) => {
         std::cmp::Ordering::Greater
     };
@@ -512,7 +526,7 @@ macro_rules! interval_choose_min_max {
 macro_rules! interval_min_max {
     ($OP:tt, $LHS:expr, $RHS:expr) => {{
         match $LHS.partial_cmp(&$RHS) {
-            Some(interval_choose_min_max!($OP)) => $RHS.clone(),
+            Some(choose_min_max!($OP)) => $RHS.clone(), // LHS loses under $OP: take RHS
             Some(_) => $LHS.clone(),
             None => {
                 return internal_err!("Comparison error while computing interval min/max")
@@ -555,10 +569,10 @@ macro_rules! min_max {
                 typed_min_max!(lhs, rhs, Boolean, $OP)
             }
             (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-                typed_min_max!(lhs, rhs, Float64, $OP)
+                typed_min_max_float!(lhs, rhs, Float64, $OP) // floats compare with total_cmp, not partial_cmp
             }
             (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-                typed_min_max!(lhs, rhs, Float32, $OP)
+                typed_min_max_float!(lhs, rhs, Float32, $OP) // same total-order semantics as Float64
             }
             (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
                 typed_min_max!(lhs, rhs, UInt64, $OP)
@@ -1103,3 +1117,41 @@ impl Accumulator for SlidingMinAccumulator {
         std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn float_min_max_with_nans() {
+        let pos_nan = f32::NAN.abs(); // sign of f32::NAN is unspecified; abs() forces +NaN
+        let zero = 0_f32;
+        let neg_inf = f32::NEG_INFINITY;
+
+        let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| {
+            for batch in values.iter() {
+                let batch =
+                    Arc::new(Float32Array::from_iter_values(batch.iter().copied()));
+                acc.update_batch(&[batch]).unwrap();
+            }
+            let result = acc.evaluate().unwrap();
+            assert_eq!(result, ScalarValue::Float32(Some(expected)));
+        };
+
+        // This test checks both comparison between batches (which uses the min_max macro
+        // defined above) and within a batch (which uses the arrow min/max compute function
+        // and verifies both respect the total order comparison for floats)
+
+        let min = || MinAccumulator::try_new(&DataType::Float32).unwrap();
+        let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap();
+
+        check(&mut min(), &[&[zero], &[pos_nan]], zero);
+        check(&mut min(), &[&[zero, pos_nan]], zero);
+        check(&mut min(), &[&[zero], &[neg_inf]], neg_inf);
+        check(&mut min(), &[&[zero, neg_inf]], neg_inf);
+        check(&mut max(), &[&[zero], &[pos_nan]], pos_nan);
+        check(&mut max(), &[&[zero, pos_nan]], pos_nan);
+        check(&mut max(), &[&[zero], &[neg_inf]], zero);
+        check(&mut max(), &[&[zero, neg_inf]], zero);
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to