This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 8de7ea45e Fix is_distinct from for float NaN values (#5446)
8de7ea45e is described below
commit 8de7ea45e96bfc1931ae51acb5a3abed7511dc7b
Author: comphead <[email protected]>
AuthorDate: Fri Mar 3 04:25:11 2023 -0800
Fix is_distinct from for float NaN values (#5446)
* impl is_distinct for floats
* impl is_distinct for floats.comments
---
.../test_files/pg_compat/pg_compat_simple.slt | 18 +++
.../core/tests/sqllogictests/test_files/select.slt | 26 +++++
datafusion/physical-expr/src/expressions/binary.rs | 59 +++++++---
.../src/expressions/binary/kernels_arrow.rs | 122 +++++++++++++++++----
4 files changed, 188 insertions(+), 37 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_simple.slt b/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_simple.slt
index 6026d07b6..5e9eb07f7 100644
--- a/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_simple.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/pg_compat/pg_compat_simple.slt
@@ -572,6 +572,24 @@ e 4 97 -13181 2047637360 6176835796788944083 158 53000 2042457019 97260165026400
e 5 -86 32514 -467659022 -8012578250188146150 254 2684 2861911482 2126626171973341689 0.12559289 0.014793053078 gxfHWUF8XgY2KdFxigxvNEXe2V2XMl
e 5 64 -26526 1689098844 8950618259486183091 224 45253 662099130 16127995415060805595 0.2897315 0.575945048386 56MZa5O1hVtX4c5sbnCfxuX5kDChqI
+# distinct_from logic for floats
+query BBBBBBBBBBB
+select
+ 'nan'::float is distinct from 'nan'::float v7,
+ 'nan'::float is not distinct from 'nan'::float v8,
+ 'nan'::float is not distinct from null v9,
+ 'nan'::float is distinct from null v10,
+ null is distinct from 'nan'::float v11,
+ null is not distinct from 'nan'::float v12,
+ 1::float is distinct from 2::float v13,
+ 'nan'::float is distinct from 1::float v14,
+ 'nan'::float is not distinct from 1::float v15,
+ 1::float is not distinct from null v16,
+ 1::float is distinct from null v17
+;
+----
+false true false true true false true true false false true
+
########
# Clean up after the test
########
diff --git a/datafusion/core/tests/sqllogictests/test_files/select.slt b/datafusion/core/tests/sqllogictests/test_files/select.slt
index a1bc52431..a3fcb8f70 100644
--- a/datafusion/core/tests/sqllogictests/test_files/select.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/select.slt
@@ -80,3 +80,29 @@ select '1' from foo order by column1;
# foo distinct order by
statement error DataFusion error: Error during planning: For SELECT DISTINCT, ORDER BY expressions column1 must appear in select list
select distinct '1' from foo order by column1;
+
+# distincts for float nan
+query BBBBBBBBBBBBBBBBB
+select
+ 'nan'::double is distinct from 'nan'::double v1,
+ 'nan'::double is not distinct from 'nan'::double v2,
+ 'nan'::double is not distinct from null v3,
+ 'nan'::double is distinct from null v4,
+ null is distinct from 'nan'::double v5,
+ null is not distinct from 'nan'::double v6,
+ 'nan'::float is distinct from 'nan'::float v7,
+ 'nan'::float is not distinct from 'nan'::float v8,
+ 'nan'::float is not distinct from null v9,
+ 'nan'::float is distinct from null v10,
+ null is distinct from 'nan'::float v11,
+ null is not distinct from 'nan'::float v12,
+ 1::float is distinct from 2::float v13,
+ 'nan'::float is distinct from 1::float v14,
+ 'nan'::float is not distinct from 1::float v15,
+ 1::float is not distinct from null v16,
+ 1::float is distinct from null v17
+;
+----
+false true false true true false false true false true true false true true false false true
+
+
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index d1b52c10f..d0e7d12ae 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -60,11 +60,12 @@ use kernels::{
use kernels_arrow::{
add_decimal_dyn_scalar, add_dyn_decimal, divide_decimal_dyn_scalar,
divide_dyn_opt_decimal, is_distinct_from, is_distinct_from_bool,
- is_distinct_from_decimal, is_distinct_from_null, is_distinct_from_utf8,
- is_not_distinct_from, is_not_distinct_from_bool, is_not_distinct_from_decimal,
- is_not_distinct_from_null, is_not_distinct_from_utf8, modulus_decimal,
- modulus_decimal_scalar, multiply_decimal_dyn_scalar, multiply_dyn_decimal,
- subtract_decimal_dyn_scalar, subtract_dyn_decimal,
+ is_distinct_from_decimal, is_distinct_from_f32, is_distinct_from_f64,
+ is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from,
+ is_not_distinct_from_bool, is_not_distinct_from_decimal, is_not_distinct_from_f32,
+ is_not_distinct_from_f64, is_not_distinct_from_null, is_not_distinct_from_utf8,
+ modulus_decimal, modulus_decimal_scalar, multiply_decimal_dyn_scalar,
+ multiply_dyn_decimal, subtract_decimal_dyn_scalar, subtract_dyn_decimal,
};
use arrow::datatypes::{DataType, Schema, TimeUnit};
@@ -183,16 +184,44 @@ macro_rules! compute_decimal_op {
}};
}
+macro_rules! compute_f32_op {
+ ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
+ let ll = $LEFT
+ .as_any()
+ .downcast_ref::<$DT>()
+ .expect("compute_op failed to downcast left side array");
+ let rr = $RIGHT
+ .as_any()
+ .downcast_ref::<$DT>()
+ .expect("compute_op failed to downcast right side array");
+ Ok(Arc::new(paste::expr! {[<$OP _f32>]}(ll, rr)?))
+ }};
+}
+
+macro_rules! compute_f64_op {
+ ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
+ let ll = $LEFT
+ .as_any()
+ .downcast_ref::<$DT>()
+ .expect("compute_op failed to downcast left side array");
+ let rr = $RIGHT
+ .as_any()
+ .downcast_ref::<$DT>()
+ .expect("compute_op failed to downcast right side array");
+ Ok(Arc::new(paste::expr! {[<$OP _f64>]}(ll, rr)?))
+ }};
+}
+
macro_rules! compute_null_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
+ .expect("compute_op failed to downcast left side array");
let rr = $RIGHT
.as_any()
.downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
+ .expect("compute_op failed to downcast right side array");
Ok(Arc::new(paste::expr! {[<$OP _null>]}(&ll, &rr)?))
}};
}
@@ -203,11 +232,11 @@ macro_rules! compute_utf8_op {
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
+ .expect("compute_op failed to downcast left side array");
let rr = $RIGHT
.as_any()
.downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
+ .expect("compute_op failed to downcast right side array");
Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?))
}};
}
@@ -218,7 +247,7 @@ macro_rules! compute_utf8_op_scalar {
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
+ .expect("compute_op failed to downcast left side array");
if let ScalarValue::Utf8(Some(string_value)) = $RIGHT {
Ok(Arc::new(paste::expr! {[<$OP _utf8_scalar>]}(
&ll,
@@ -317,7 +346,7 @@ macro_rules! compute_op_scalar {
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
+ .expect("compute_op failed to downcast left side array");
Ok(Arc::new(paste::expr! {[<$OP _scalar>]}(
&ll,
$RIGHT.try_into()?,
@@ -391,11 +420,11 @@ macro_rules! compute_op {
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
+ .expect("compute_op failed to downcast left side array");
let rr = $RIGHT
.as_any()
.downcast_ref::<$DT>()
- .expect("compute_op failed to downcast array");
+ .expect("compute_op failed to downcast right side array");
Ok(Arc::new($OP(&ll, &rr)?))
}};
// invoke unary operator
@@ -539,8 +568,8 @@ macro_rules! binary_array_op {
DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
- DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
- DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
+ DataType::Float32 => compute_f32_op!($LEFT, $RIGHT, $OP, Float32Array),
+ DataType::Float64 => compute_f64_op!($LEFT, $RIGHT, $OP, Float64Array),
DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
compute_op!($LEFT, $RIGHT, $OP, TimestampNanosecondArray)
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
index 365a7de9b..6ef515b33 100644
--- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
+++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
@@ -32,6 +32,14 @@ use std::sync::Arc;
// Simple (low performance) kernels until optimized kernels are added to arrow
// See https://github.com/apache/arrow-rs/issues/960
+macro_rules! distinct_float {
+ ($LEFT:expr, $RIGHT:expr, $LEFT_ISNULL:expr, $RIGHT_ISNULL:expr) => {{
+ $LEFT_ISNULL != $RIGHT_ISNULL
+ || $LEFT.is_nan() != $RIGHT.is_nan()
+ || (!$LEFT.is_nan() && !$RIGHT.is_nan() && $LEFT != $RIGHT)
+ }};
+}
+
pub(crate) fn is_distinct_from_bool(
left: &BooleanArray,
right: &BooleanArray,
@@ -62,22 +70,13 @@ pub(crate) fn is_distinct_from<T>(
where
T: ArrowNumericType,
{
- let left_data = left.data();
- let right_data = right.data();
- let array_len = left_data.len().min(right_data.len());
-
- let left_values = left.values();
- let right_values = right.values();
-
- let distinct = arrow_buffer::MutableBuffer::collect_bool(array_len, |i| {
- left_data.is_null(i) != right_data.is_null(i) || left_values[i] != right_values[i]
- });
-
- let array_data = ArrayData::builder(arrow_schema::DataType::Boolean)
- .len(array_len)
- .add_buffer(distinct.into());
-
- Ok(BooleanArray::from(unsafe { array_data.build_unchecked() }))
+ distinct(
+ left,
+ right,
+ |left_value, right_value, left_isnull, right_isnull| {
+ left_isnull != right_isnull || left_value != right_value
+ },
+ )
}
pub(crate) fn is_not_distinct_from<T>(
@@ -87,18 +86,45 @@ pub(crate) fn is_not_distinct_from<T>(
where
T: ArrowNumericType,
{
- let left_data = left.data();
- let right_data = right.data();
- let array_len = left_data.len().min(right_data.len());
+ distinct(
+ left,
+ right,
+ |left_value, right_value, left_isnull, right_isnull| {
+ !(left_isnull != right_isnull || left_value != right_value)
+ },
+ )
+}
+fn distinct<
+ T,
+ F: FnMut(
+ <T as ArrowPrimitiveType>::Native,
+ <T as ArrowPrimitiveType>::Native,
+ bool,
+ bool,
+ ) -> bool,
+>(
+ left: &PrimitiveArray<T>,
+ right: &PrimitiveArray<T>,
+ mut op: F,
+) -> Result<BooleanArray>
+where
+ T: ArrowNumericType,
+{
let left_values = left.values();
let right_values = right.values();
+ let left_data = left.data();
+ let right_data = right.data();
+ let array_len = left_data.len().min(right_data.len());
let distinct = arrow_buffer::MutableBuffer::collect_bool(array_len, |i| {
- !(left_data.is_null(i) != right_data.is_null(i)
- || left_values[i] != right_values[i])
+ op(
+ left_values[i],
+ right_values[i],
+ left_data.is_null(i),
+ right_data.is_null(i),
+ )
});
-
let array_data = ArrayData::builder(arrow_schema::DataType::Boolean)
.len(array_len)
.add_buffer(distinct.into());
@@ -106,6 +132,58 @@ where
Ok(BooleanArray::from(unsafe { array_data.build_unchecked() }))
}
+pub(crate) fn is_distinct_from_f32(
+ left: &Float32Array,
+ right: &Float32Array,
+) -> Result<BooleanArray> {
+ distinct(
+ left,
+ right,
+ |left_value, right_value, left_isnull, right_isnull| {
+ distinct_float!(left_value, right_value, left_isnull, right_isnull)
+ },
+ )
+}
+
+pub(crate) fn is_not_distinct_from_f32(
+ left: &Float32Array,
+ right: &Float32Array,
+) -> Result<BooleanArray> {
+ distinct(
+ left,
+ right,
+ |left_value, right_value, left_isnull, right_isnull| {
+ !(distinct_float!(left_value, right_value, left_isnull, right_isnull))
+ },
+ )
+}
+
+pub(crate) fn is_distinct_from_f64(
+ left: &Float64Array,
+ right: &Float64Array,
+) -> Result<BooleanArray> {
+ distinct(
+ left,
+ right,
+ |left_value, right_value, left_isnull, right_isnull| {
+ distinct_float!(left_value, right_value, left_isnull, right_isnull)
+ },
+ )
+}
+
+pub(crate) fn is_not_distinct_from_f64(
+ left: &Float64Array,
+ right: &Float64Array,
+) -> Result<BooleanArray> {
+ distinct(
+ left,
+ right,
+ |left_value, right_value, left_isnull, right_isnull| {
+ !(distinct_float!(left_value, right_value, left_isnull, right_isnull))
+ },
+ )
+}
+
pub(crate) fn is_distinct_from_utf8<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,