This is an automated email from the ASF dual-hosted git repository.

yjshen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new ac35f346c Numeric, String, Boolean comparisons with literal NULL (#2481)
ac35f346c is described below

commit ac35f346c3209e647b1d819b3590d3b5f4700b1b
Author: DuRipeng <[email protected]>
AuthorDate: Wed May 11 09:47:08 2022 +0800

    Numeric, String, Boolean comparisons with literal NULL (#2481)
---
 datafusion/core/tests/sql/expr.rs                  | 120 +++++++++++++++++++++
 datafusion/expr/src/binary_rule.rs                 |   2 +
 datafusion/physical-expr/src/expressions/binary.rs |  80 ++++++++------
 3 files changed, 167 insertions(+), 35 deletions(-)

diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs
index c39a7a8cd..b8a002657 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -1201,3 +1201,123 @@ async fn nested_subquery() -> Result<()> {
     assert_batches_eq!(expected, &actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn comparisons_with_null() -> Result<()> {
+    let ctx = SessionContext::new();
+    // 1. Numeric comparison with NULL
+    let sql = "select column1 < NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 
5.4)) as t";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-------------------+",
+        "| t.column1 Lt NULL |",
+        "+-------------------+",
+        "|                   |",
+        "|                   |",
+        "+-------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql =
+        "select column1 <= NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) 
as t";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+---------------------+",
+        "| t.column1 LtEq NULL |",
+        "+---------------------+",
+        "|                     |",
+        "|                     |",
+        "+---------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "select column1 > NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 
5.4)) as t";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-------------------+",
+        "| t.column1 Gt NULL |",
+        "+-------------------+",
+        "|                   |",
+        "|                   |",
+        "+-------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql =
+        "select column1 >= NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) 
as t";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+---------------------+",
+        "| t.column1 GtEq NULL |",
+        "+---------------------+",
+        "|                     |",
+        "|                     |",
+        "+---------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "select column1 = NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 
5.4)) as t";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-------------------+",
+        "| t.column1 Eq NULL |",
+        "+-------------------+",
+        "|                   |",
+        "|                   |",
+        "+-------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql =
+        "select column1 != NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 5.4)) 
as t";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----------------------+",
+        "| t.column1 NotEq NULL |",
+        "+----------------------+",
+        "|                      |",
+        "|                      |",
+        "+----------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    // 1.1 Float value comparison with NULL
+    let sql = "select column3 < NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 
5.4)) as t";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-------------------+",
+        "| t.column3 Lt NULL |",
+        "+-------------------+",
+        "|                   |",
+        "|                   |",
+        "+-------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    // String comparison with NULL
+    let sql = "select column2 < NULL from (VALUES (1, 'foo' ,2.3), (2, 'bar', 
5.4)) as t";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-------------------+",
+        "| t.column2 Lt NULL |",
+        "+-------------------+",
+        "|                   |",
+        "|                   |",
+        "+-------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    // Boolean comparison with NULL
+    let sql = "select column1 < NULL from (VALUES (true), (false)) as t";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-------------------+",
+        "| t.column1 Lt NULL |",
+        "+-------------------+",
+        "|                   |",
+        "|                   |",
+        "+-------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
diff --git a/datafusion/expr/src/binary_rule.rs b/datafusion/expr/src/binary_rule.rs
index 63a9712fd..f3d753467 100644
--- a/datafusion/expr/src/binary_rule.rs
+++ b/datafusion/expr/src/binary_rule.rs
@@ -161,6 +161,7 @@ fn comparison_eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
         .or_else(|| dictionary_coercion(lhs_type, rhs_type))
         .or_else(|| temporal_coercion(lhs_type, rhs_type))
         .or_else(|| string_coercion(lhs_type, rhs_type))
+        .or_else(|| null_coercion(lhs_type, rhs_type))
 }
 
 fn comparison_order_coercion(
@@ -177,6 +178,7 @@ fn comparison_order_coercion(
         .or_else(|| string_coercion(lhs_type, rhs_type))
         .or_else(|| dictionary_coercion(lhs_type, rhs_type))
         .or_else(|| temporal_coercion(lhs_type, rhs_type))
+        .or_else(|| null_coercion(lhs_type, rhs_type))
 }
 
 fn comparison_binary_numeric_coercion(
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index 060f30cb2..25beea5b2 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -657,17 +657,15 @@ macro_rules! compute_utf8_op_scalar {
 
 /// Invoke a compute kernel on a data array and a scalar value
 macro_rules! compute_utf8_op_dyn_scalar {
-    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
         if let Some(string_value) = $RIGHT {
             Ok(Arc::new(paste::expr! {[<$OP _dyn_utf8_scalar>]}(
                 $LEFT,
                 &string_value,
             )?))
         } else {
-            Err(DataFusionError::Internal(format!(
-                "compute_utf8_op_scalar for '{}' failed with literal 'none' 
value",
-                stringify!($OP),
-            )))
+            // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
+            Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
         }
     }};
 }
@@ -691,7 +689,7 @@ macro_rules! compute_bool_op_scalar {
 
 /// Invoke a compute kernel on a boolean data array and a scalar value
 macro_rules! compute_bool_op_dyn_scalar {
-    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
         // generate the scalar function name, such as lt_dyn_bool_scalar, from 
the $OP parameter
         // (which could have a value of lt) and the suffix _scalar
         if let Some(b) = $RIGHT {
@@ -700,10 +698,8 @@ macro_rules! compute_bool_op_dyn_scalar {
                 b,
             )?))
         } else {
-            Err(DataFusionError::Internal(format!(
-                "compute_utf8_op_scalar for '{}' failed with literal 'none' 
value",
-                stringify!($OP),
-            )))
+            // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
+            Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
         }
     }};
 }
@@ -751,8 +747,9 @@ macro_rules! compute_op_scalar {
 
 /// Invoke a dyn compute kernel on a data array and a scalar value
 /// LEFT is Primitive or Dictionart array of numeric values, RIGHT is scalar 
value
+/// OP_TYPE is the return type of scalar function
 macro_rules! compute_op_dyn_scalar {
-    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
         // generate the scalar function name, such as lt_dyn_scalar, from the 
$OP parameter
         // (which could have a value of lt_dyn) and the suffix _scalar
         if let Some(value) = $RIGHT {
@@ -761,10 +758,8 @@ macro_rules! compute_op_dyn_scalar {
                 value,
             )?))
         } else {
-            Err(DataFusionError::Internal(format!(
-                "compute_utf8_op_scalar for '{}' failed with literal 'none' 
value",
-                stringify!($OP),
-            )))
+            // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE
+            Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len())))
         }
     }};
 }
@@ -1125,22 +1120,22 @@ impl PhysicalExpr for BinaryExpr {
 /// such as Utf8 strings.
 #[macro_export]
 macro_rules! binary_array_op_dyn_scalar {
-    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
         let result: Result<Arc<dyn Array>> = match $RIGHT {
-            ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, 
$OP),
+            ScalarValue::Boolean(b) => compute_bool_op_dyn_scalar!($LEFT, b, 
$OP, $OP_TYPE),
             ScalarValue::Decimal128(..) => compute_decimal_op_scalar!($LEFT, 
$RIGHT, $OP, DecimalArray),
-            ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP),
-            ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, 
$OP),
-            ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
-            ScalarValue::Int16(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
-            ScalarValue::Int32(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
-            ScalarValue::Int64(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
-            ScalarValue::UInt8(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
-            ScalarValue::UInt16(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
-            ScalarValue::UInt32(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
-            ScalarValue::UInt64(v) => compute_op_dyn_scalar!($LEFT, v, $OP),
-            ScalarValue::Float32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, 
Float32Array),
-            ScalarValue::Float64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, 
Float64Array),
+            ScalarValue::Utf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, $OP, 
$OP_TYPE),
+            ScalarValue::LargeUtf8(v) => compute_utf8_op_dyn_scalar!($LEFT, v, 
$OP, $OP_TYPE),
+            ScalarValue::Int8(v) => compute_op_dyn_scalar!($LEFT, v, $OP, 
$OP_TYPE),
+            ScalarValue::Int16(v) => compute_op_dyn_scalar!($LEFT, v, $OP, 
$OP_TYPE),
+            ScalarValue::Int32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, 
$OP_TYPE),
+            ScalarValue::Int64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, 
$OP_TYPE),
+            ScalarValue::UInt8(v) => compute_op_dyn_scalar!($LEFT, v, $OP, 
$OP_TYPE),
+            ScalarValue::UInt16(v) => compute_op_dyn_scalar!($LEFT, v, $OP, 
$OP_TYPE),
+            ScalarValue::UInt32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, 
$OP_TYPE),
+            ScalarValue::UInt64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, 
$OP_TYPE),
+            ScalarValue::Float32(v) => compute_op_dyn_scalar!($LEFT, v, $OP, 
$OP_TYPE),
+            ScalarValue::Float64(v) => compute_op_dyn_scalar!($LEFT, v, $OP, 
$OP_TYPE),
             ScalarValue::Date32(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, 
Date32Array),
             ScalarValue::Date64(_) => compute_op_scalar!($LEFT, $RIGHT, $OP, 
Date64Array),
             ScalarValue::TimestampSecond(..) => compute_op_scalar!($LEFT, 
$RIGHT, $OP, TimestampSecondArray),
@@ -1163,22 +1158,37 @@ impl BinaryExpr {
     ) -> Result<Option<Result<ArrayRef>>> {
         let scalar_result = match &self.op {
             Operator::Lt => {
-                binary_array_op_dyn_scalar!(array, scalar.clone(), lt)
+                binary_array_op_dyn_scalar!(array, scalar.clone(), lt, 
&DataType::Boolean)
             }
             Operator::LtEq => {
-                binary_array_op_dyn_scalar!(array, scalar.clone(), lt_eq)
+                binary_array_op_dyn_scalar!(
+                    array,
+                    scalar.clone(),
+                    lt_eq,
+                    &DataType::Boolean
+                )
             }
             Operator::Gt => {
-                binary_array_op_dyn_scalar!(array, scalar.clone(), gt)
+                binary_array_op_dyn_scalar!(array, scalar.clone(), gt, 
&DataType::Boolean)
             }
             Operator::GtEq => {
-                binary_array_op_dyn_scalar!(array, scalar.clone(), gt_eq)
+                binary_array_op_dyn_scalar!(
+                    array,
+                    scalar.clone(),
+                    gt_eq,
+                    &DataType::Boolean
+                )
             }
             Operator::Eq => {
-                binary_array_op_dyn_scalar!(array, scalar.clone(), eq)
+                binary_array_op_dyn_scalar!(array, scalar.clone(), eq, 
&DataType::Boolean)
             }
             Operator::NotEq => {
-                binary_array_op_dyn_scalar!(array, scalar.clone(), neq)
+                binary_array_op_dyn_scalar!(
+                    array,
+                    scalar.clone(),
+                    neq,
+                    &DataType::Boolean
+                )
             }
             Operator::Like => {
                 binary_string_array_op_scalar!(array, scalar.clone(), like)

Reply via email to