This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8de7ea45e Fix is_distinct from for float NaN values (#5446)
8de7ea45e is described below

commit 8de7ea45e96bfc1931ae51acb5a3abed7511dc7b
Author: comphead <[email protected]>
AuthorDate: Fri Mar 3 04:25:11 2023 -0800

    Fix is_distinct from for float NaN values (#5446)
    
    * impl is_distinct for floats
    
    * impl is_distinct for floats.comments
---
 .../test_files/pg_compat/pg_compat_simple.slt      |  18 +++
 .../core/tests/sqllogictests/test_files/select.slt |  26 +++++
 datafusion/physical-expr/src/expressions/binary.rs |  59 +++++++---
 .../src/expressions/binary/kernels_arrow.rs        | 122 +++++++++++++++++----
 4 files changed, 188 insertions(+), 37 deletions(-)

diff --git a/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_simple.slt b/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_simple.slt
index 6026d07b6..5e9eb07f7 100644
--- a/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_simple.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_simple.slt
@@ -572,6 +572,25 @@ e 4 97 -13181 2047637360 6176835796788944083 158 53000 2042457019 97260165026400
 e 5 -86 32514 -467659022 -8012578250188146150 254 2684 2861911482 2126626171973341689 0.12559289 0.014793053078 gxfHWUF8XgY2KdFxigxvNEXe2V2XMl
 e 5 64 -26526 1689098844 8950618259486183091 224 45253 662099130 16127995415060805595 0.2897315 0.575945048386 56MZa5O1hVtX4c5sbnCfxuX5kDChqI
 
+# IS [NOT] DISTINCT FROM with float operands:
+# NaN is not distinct from NaN, and NULL is distinct from every non-NULL value
+query BBBBBBBBBBB
+select
+    'nan'::float is distinct from 'nan'::float v1,
+    'nan'::float is not distinct from 'nan'::float v2,
+    'nan'::float is not distinct from null v3,
+    'nan'::float is distinct from null v4,
+    null is distinct from 'nan'::float v5,
+    null is not distinct from 'nan'::float v6,
+    1::float is distinct from 2::float v7,
+    'nan'::float is distinct from 1::float v8,
+    'nan'::float is not distinct from 1::float v9,
+    1::float is not distinct from null v10,
+    1::float is distinct from null v11
+;
+----
+false true false true true false true true false false true
+
 ########
 # Clean up after the test
 ########
diff --git a/datafusion/core/tests/sqllogictests/test_files/select.slt b/datafusion/core/tests/sqllogictests/test_files/select.slt
index a1bc52431..a3fcb8f70 100644
--- a/datafusion/core/tests/sqllogictests/test_files/select.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/select.slt
@@ -80,3 +80,29 @@ select '1' from foo order by column1;
 # foo distinct order by
 statement error DataFusion error: Error during planning: For SELECT DISTINCT, ORDER BY expressions column1 must appear in select list
 select distinct '1' from foo order by column1;
+
+# IS [NOT] DISTINCT FROM with double and float operands:
+# NaN is not distinct from NaN, and NULL is distinct from every non-NULL value
+query BBBBBBBBBBBBBBBBB
+select
+    'nan'::double is distinct from 'nan'::double v1,
+    'nan'::double is not distinct from 'nan'::double v2,
+    'nan'::double is not distinct from null v3,
+    'nan'::double is distinct from null v4,
+    null is distinct from 'nan'::double v5,
+    null is not distinct from 'nan'::double v6,
+    'nan'::float is distinct from 'nan'::float v7,
+    'nan'::float is not distinct from 'nan'::float v8,
+    'nan'::float is not distinct from null v9,
+    'nan'::float is distinct from null v10,
+    null is distinct from 'nan'::float v11,
+    null is not distinct from 'nan'::float v12,
+    1::float is distinct from 2::float v13,
+    'nan'::float is distinct from 1::float v14,
+    'nan'::float is not distinct from 1::float v15,
+    1::float is not distinct from null v16,
+    1::float is distinct from null v17
+;
+----
+false true false true true false false true false true true false true true false false true
+
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index d1b52c10f..d0e7d12ae 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -60,11 +60,12 @@ use kernels::{
 use kernels_arrow::{
     add_decimal_dyn_scalar, add_dyn_decimal, divide_decimal_dyn_scalar,
     divide_dyn_opt_decimal, is_distinct_from, is_distinct_from_bool,
-    is_distinct_from_decimal, is_distinct_from_null, is_distinct_from_utf8,
-    is_not_distinct_from, is_not_distinct_from_bool, 
is_not_distinct_from_decimal,
-    is_not_distinct_from_null, is_not_distinct_from_utf8, modulus_decimal,
-    modulus_decimal_scalar, multiply_decimal_dyn_scalar, multiply_dyn_decimal,
-    subtract_decimal_dyn_scalar, subtract_dyn_decimal,
+    is_distinct_from_decimal, is_distinct_from_f32, is_distinct_from_f64,
+    is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from,
+    is_not_distinct_from_bool, is_not_distinct_from_decimal, 
is_not_distinct_from_f32,
+    is_not_distinct_from_f64, is_not_distinct_from_null, 
is_not_distinct_from_utf8,
+    modulus_decimal, modulus_decimal_scalar, multiply_decimal_dyn_scalar,
+    multiply_dyn_decimal, subtract_decimal_dyn_scalar, subtract_dyn_decimal,
 };
 
 use arrow::datatypes::{DataType, Schema, TimeUnit};
@@ -183,16 +184,47 @@ macro_rules! compute_decimal_op {
     }};
 }
 
+/// Shared body of `compute_f32_op!` and `compute_f64_op!`: downcasts both
+/// operands to `$DT` and calls the kernel whose name is `$OP` followed by
+/// `$SUFFIX` (e.g. `is_distinct_from` + `_f64`), propagating its error.
+macro_rules! compute_float_op {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident, $SUFFIX:ident) => {{
+        let ll = $LEFT
+            .as_any()
+            .downcast_ref::<$DT>()
+            .expect("compute_op failed to downcast left side array");
+        let rr = $RIGHT
+            .as_any()
+            .downcast_ref::<$DT>()
+            .expect("compute_op failed to downcast right side array");
+        Ok(Arc::new(paste::expr! {[<$OP $SUFFIX>]}(ll, rr)?))
+    }};
+}
+
+/// Calls the `_f32` variant of kernel `$OP` on two `$DT` (`Float32Array`) operands.
+macro_rules! compute_f32_op {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
+        compute_float_op!($LEFT, $RIGHT, $OP, $DT, _f32)
+    }};
+}
+
+/// Calls the `_f64` variant of kernel `$OP` on two `$DT` (`Float64Array`) operands.
+macro_rules! compute_f64_op {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
+        compute_float_op!($LEFT, $RIGHT, $OP, $DT, _f64)
+    }};
+}
+
 macro_rules! compute_null_op {
     ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
         let ll = $LEFT
             .as_any()
             .downcast_ref::<$DT>()
-            .expect("compute_op failed to downcast array");
+            .expect("compute_op failed to downcast left side array");
         let rr = $RIGHT
             .as_any()
             .downcast_ref::<$DT>()
-            .expect("compute_op failed to downcast array");
+            .expect("compute_op failed to downcast right side array");
         Ok(Arc::new(paste::expr! {[<$OP _null>]}(&ll, &rr)?))
     }};
 }
@@ -203,11 +232,11 @@ macro_rules! compute_utf8_op {
         let ll = $LEFT
             .as_any()
             .downcast_ref::<$DT>()
-            .expect("compute_op failed to downcast array");
+            .expect("compute_op failed to downcast left side array");
         let rr = $RIGHT
             .as_any()
             .downcast_ref::<$DT>()
-            .expect("compute_op failed to downcast array");
+            .expect("compute_op failed to downcast right side array");
         Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?))
     }};
 }
@@ -218,7 +247,7 @@ macro_rules! compute_utf8_op_scalar {
         let ll = $LEFT
             .as_any()
             .downcast_ref::<$DT>()
-            .expect("compute_op failed to downcast array");
+            .expect("compute_op failed to downcast left side array");
         if let ScalarValue::Utf8(Some(string_value)) = $RIGHT {
             Ok(Arc::new(paste::expr! {[<$OP _utf8_scalar>]}(
                 &ll,
@@ -317,7 +346,7 @@ macro_rules! compute_op_scalar {
             let ll = $LEFT
                 .as_any()
                 .downcast_ref::<$DT>()
-                .expect("compute_op failed to downcast array");
+                .expect("compute_op failed to downcast left side array");
             Ok(Arc::new(paste::expr! {[<$OP _scalar>]}(
                 &ll,
                 $RIGHT.try_into()?,
@@ -391,11 +420,11 @@ macro_rules! compute_op {
         let ll = $LEFT
             .as_any()
             .downcast_ref::<$DT>()
-            .expect("compute_op failed to downcast array");
+            .expect("compute_op failed to downcast left side array");
         let rr = $RIGHT
             .as_any()
             .downcast_ref::<$DT>()
-            .expect("compute_op failed to downcast array");
+            .expect("compute_op failed to downcast right side array");
         Ok(Arc::new($OP(&ll, &rr)?))
     }};
     // invoke unary operator
@@ -539,8 +568,9 @@ macro_rules! binary_array_op {
             DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
             DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
             DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
-            DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
-            DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
+            // Float types dispatch to dedicated `<op>_f32` / `<op>_f64` kernels.
+            DataType::Float32 => compute_f32_op!($LEFT, $RIGHT, $OP, Float32Array),
+            DataType::Float64 => compute_f64_op!($LEFT, $RIGHT, $OP, Float64Array),
             DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
             DataType::Timestamp(TimeUnit::Nanosecond, _) => {
                 compute_op!($LEFT, $RIGHT, $OP, TimestampNanosecondArray)
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
index 365a7de9b..6ef515b33 100644
--- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
+++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
@@ -32,6 +32,23 @@ use std::sync::Arc;
 // Simple (low performance) kernels until optimized kernels are added to arrow
 // See https://github.com/apache/arrow-rs/issues/960
 
+/// Returns `true` when two float slots are distinct under SQL
+/// `IS DISTINCT FROM` semantics:
+///
+/// * a NULL is distinct from a non-NULL value, and two NULLs are not distinct
+///   (the values stored behind NULL slots are never compared);
+/// * NaN is not distinct from NaN (unlike IEEE 754 `!=`), but is distinct
+///   from every non-NaN value;
+/// * otherwise the two values are compared with `!=`.
+macro_rules! distinct_float {
+    ($LEFT:expr, $RIGHT:expr, $LEFT_ISNULL:expr, $RIGHT_ISNULL:expr) => {{
+        $LEFT_ISNULL != $RIGHT_ISNULL
+            || (!$LEFT_ISNULL
+                && $LEFT != $RIGHT
+                && !($LEFT.is_nan() && $RIGHT.is_nan()))
+    }};
+}
+
 pub(crate) fn is_distinct_from_bool(
     left: &BooleanArray,
     right: &BooleanArray,
@@ -62,22 +70,15 @@ pub(crate) fn is_distinct_from<T>(
 where
     T: ArrowNumericType,
 {
-    let left_data = left.data();
-    let right_data = right.data();
-    let array_len = left_data.len().min(right_data.len());
-
-    let left_values = left.values();
-    let right_values = right.values();
-
-    let distinct = arrow_buffer::MutableBuffer::collect_bool(array_len, |i| {
-        left_data.is_null(i) != right_data.is_null(i) || left_values[i] != right_values[i]
-    });
-
-    let array_data = ArrayData::builder(arrow_schema::DataType::Boolean)
-        .len(array_len)
-        .add_buffer(distinct.into());
-
-    Ok(BooleanArray::from(unsafe { array_data.build_unchecked() }))
+    distinct(
+        left,
+        right,
+        |left_value, right_value, left_isnull, right_isnull| {
+            // NULL vs non-NULL is distinct and two NULLs are not; the values
+            // behind NULL slots are never compared.
+            left_isnull != right_isnull || (!left_isnull && left_value != right_value)
+        },
+    )
 }
 
 pub(crate) fn is_not_distinct_from<T>(
@@ -87,18 +86,55 @@ pub(crate) fn is_not_distinct_from<T>(
 where
     T: ArrowNumericType,
 {
-    let left_data = left.data();
-    let right_data = right.data();
-    let array_len = left_data.len().min(right_data.len());
+    distinct(
+        left,
+        right,
+        |left_value, right_value, left_isnull, right_isnull| {
+            // Not distinct: both NULL, or both non-NULL with equal values.
+            // The values behind NULL slots are never compared.
+            !(left_isnull != right_isnull || (!left_isnull && left_value != right_value))
+        },
+    )
+}
 
+/// Applies `op` element-wise to `left` and `right` and collects the results
+/// into a `BooleanArray` that has no null buffer.
+///
+/// `op` is called as `op(left_value, right_value, left_isnull, right_isnull)`.
+/// Values are read from the underlying buffers even for NULL slots, so `op`
+/// must not depend on a value whose `*_isnull` flag is `true`.
+///
+/// Only the first `min(left.len(), right.len())` elements are processed.
+fn distinct<
+    T,
+    F: FnMut(
+        <T as ArrowPrimitiveType>::Native,
+        <T as ArrowPrimitiveType>::Native,
+        bool,
+        bool,
+    ) -> bool,
+>(
+    left: &PrimitiveArray<T>,
+    right: &PrimitiveArray<T>,
+    mut op: F,
+) -> Result<BooleanArray>
+where
+    T: ArrowNumericType,
+{
     let left_values = left.values();
     let right_values = right.values();
+    let left_data = left.data();
+    let right_data = right.data();
 
+    let array_len = left_data.len().min(right_data.len());
     let distinct = arrow_buffer::MutableBuffer::collect_bool(array_len, |i| {
-        !(left_data.is_null(i) != right_data.is_null(i)
-            || left_values[i] != right_values[i])
+        op(
+            left_values[i],
+            right_values[i],
+            left_data.is_null(i),
+            right_data.is_null(i),
+        )
     });
-
     let array_data = ArrayData::builder(arrow_schema::DataType::Boolean)
         .len(array_len)
         .add_buffer(distinct.into());
@@ -106,6 +132,69 @@ where
+    // SAFETY: `array_data` describes a Boolean array of `array_len` slots whose
+    // only buffer is the values bitmap built by `collect_bool(array_len, ..)`,
+    // and no null buffer is attached.
     Ok(BooleanArray::from(unsafe { array_data.build_unchecked() }))
 }
 
+/// Element-wise `IS DISTINCT FROM` for `Float32Array` operands; uses
+/// `distinct_float!`, so NaN is not distinct from NaN.
+pub(crate) fn is_distinct_from_f32(
+    left: &Float32Array,
+    right: &Float32Array,
+) -> Result<BooleanArray> {
+    distinct(
+        left,
+        right,
+        |left_value, right_value, left_isnull, right_isnull| {
+            distinct_float!(left_value, right_value, left_isnull, right_isnull)
+        },
+    )
+}
+
+/// Element-wise `IS NOT DISTINCT FROM` for `Float32Array` operands; the
+/// negation of `distinct_float!`, so NaN is not distinct from NaN.
+pub(crate) fn is_not_distinct_from_f32(
+    left: &Float32Array,
+    right: &Float32Array,
+) -> Result<BooleanArray> {
+    distinct(
+        left,
+        right,
+        |left_value, right_value, left_isnull, right_isnull| {
+            !(distinct_float!(left_value, right_value, left_isnull, right_isnull))
+        },
+    )
+}
+
+/// Element-wise `IS DISTINCT FROM` for `Float64Array` operands; uses
+/// `distinct_float!`, so NaN is not distinct from NaN.
+pub(crate) fn is_distinct_from_f64(
+    left: &Float64Array,
+    right: &Float64Array,
+) -> Result<BooleanArray> {
+    distinct(
+        left,
+        right,
+        |left_value, right_value, left_isnull, right_isnull| {
+            distinct_float!(left_value, right_value, left_isnull, right_isnull)
+        },
+    )
+}
+
+/// Element-wise `IS NOT DISTINCT FROM` for `Float64Array` operands; the
+/// negation of `distinct_float!`, so NaN is not distinct from NaN.
+pub(crate) fn is_not_distinct_from_f64(
+    left: &Float64Array,
+    right: &Float64Array,
+) -> Result<BooleanArray> {
+    distinct(
+        left,
+        right,
+        |left_value, right_value, left_isnull, right_isnull| {
+            !(distinct_float!(left_value, right_value, left_isnull, right_isnull))
+        },
+    )
+}
+
 pub(crate) fn is_distinct_from_utf8<OffsetSize: OffsetSizeTrait>(
     left: &GenericStringArray<OffsetSize>,
     right: &GenericStringArray<OffsetSize>,

Reply via email to