This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 00850a4  Support `=`, `<`, `<=`, `>`, `>=`, `!=`, `is distinct from`, `is not distinct from` for `BooleanArray` (#1163)
00850a4 is described below

commit 00850a4f6ab63aa943cea2f00d7ede74d8730c1b
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Nov 20 07:59:30 2021 -0500

    Support `=`, `<`, `<=`, `>`, `>=`, `!=`, `is distinct from`, `is not distinct from` for `BooleanArray` (#1163)
    
    * Support `=`, `<`, `<=`, `>`, `>=`, `!=`, `is distinct from`, `is not distinct from` for `BooleanArray`
    
    * Update datafusion/src/physical_plan/expressions/binary.rs
    
    Co-authored-by: rdettai <[email protected]>
    
    * Apply suggestions from code review
    
    Co-authored-by: Jiayu Liu <[email protected]>
    
    Co-authored-by: rdettai <[email protected]>
    Co-authored-by: Jiayu Liu <[email protected]>
---
 datafusion/src/physical_optimizer/pruning.rs       |  24 +-
 datafusion/src/physical_plan/expressions/binary.rs | 430 ++++++++++++++++++++-
 datafusion/src/physical_plan/expressions/nullif.rs |   4 +-
 .../src/physical_plan/file_format/parquet.rs       |  12 +-
 datafusion/tests/sql.rs                            | 147 +++++++
 5 files changed, 592 insertions(+), 25 deletions(-)

diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs
index ba01865..24334d7 100644
--- a/datafusion/src/physical_optimizer/pruning.rs
+++ b/datafusion/src/physical_optimizer/pruning.rs
@@ -1374,36 +1374,24 @@ mod tests {
 
     #[test]
     fn prune_bool_column_eq_true() {
-        let (schema, statistics, _, _) = bool_setup();
+        let (schema, statistics, expected_true, _) = bool_setup();
 
         // b1 = true
         let expr = col("b1").eq(lit(true));
         let p = PruningPredicate::try_new(&expr, schema).unwrap();
-        let result = p.prune(&statistics).unwrap_err();
-        assert!(
-            result.to_string().contains(
-                "Data type Boolean not supported for scalar operation 'lt_eq' 
on dyn array"
-            ),
-            "{}",
-            result
-        )
+        let result = p.prune(&statistics).unwrap();
+        assert_eq!(result, expected_true);
     }
 
     #[test]
     fn prune_bool_not_column_eq_true() {
-        let (schema, statistics, _, _) = bool_setup();
+        let (schema, statistics, _, expected_false) = bool_setup();
 
         // !b1 = true
         let expr = col("b1").not().eq(lit(true));
         let p = PruningPredicate::try_new(&expr, schema).unwrap();
-        let result = p.prune(&statistics).unwrap_err();
-        assert!(
-            result.to_string().contains(
-                "Data type Boolean not supported for scalar operation 'lt_eq' 
on dyn array"
-            ),
-            "{}",
-            result
-        )
+        let result = p.prune(&statistics).unwrap();
+        assert_eq!(result, expected_false);
     }
 
     /// Creates setup for int32 chunk pruning
diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs
index 456e8d4..92d2a8b 100644
--- a/datafusion/src/physical_plan/expressions/binary.rs
+++ b/datafusion/src/physical_plan/expressions/binary.rs
@@ -25,6 +25,10 @@ use arrow::compute::kernels::arithmetic::{
 use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
 use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq};
 use arrow::compute::kernels::comparison::{
+    eq_bool, eq_bool_scalar, gt_bool, gt_eq_bool, lt_bool, lt_eq_bool, 
neq_bool,
+    neq_bool_scalar,
+};
+use arrow::compute::kernels::comparison::{
     eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar,
 };
 use arrow::compute::kernels::comparison::{
@@ -49,6 +53,64 @@ use super::coercion::{
     eq_coercion, like_coercion, numerical_coercion, order_coercion, 
string_coercion,
 };
 
+// Simple (low performance) kernels until optimized kernels are added to arrow
+// See https://github.com/apache/arrow-rs/issues/960
+
+fn is_distinct_from_bool(
+    left: &BooleanArray,
+    right: &BooleanArray,
+) -> Result<BooleanArray> {
+    // Different from `neq_bool` because `null is distinct from null` is false 
and not null
+    Ok(left
+        .iter()
+        .zip(right.iter())
+        .map(|(left, right)| Some(left != right))
+        .collect())
+}
+
+fn is_not_distinct_from_bool(
+    left: &BooleanArray,
+    right: &BooleanArray,
+) -> Result<BooleanArray> {
+    Ok(left
+        .iter()
+        .zip(right.iter())
+        .map(|(left, right)| Some(left == right))
+        .collect())
+}
+
+// TODO use arrow-rs kernels when available. See
+// https://github.com/apache/arrow-rs/issues/959
+#[allow(clippy::bool_comparison)]
+fn lt_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
+    Ok(left
+        .iter()
+        .map(|left| left.map(|left| left < right))
+        .collect())
+}
+
+fn lt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> 
{
+    Ok(left
+        .iter()
+        .map(|left| left.map(|left| left <= right))
+        .collect())
+}
+
+#[allow(clippy::bool_comparison)]
+fn gt_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> {
+    Ok(left
+        .iter()
+        .map(|left| left.map(|left| left > right))
+        .collect())
+}
+
+fn gt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result<BooleanArray> 
{
+    Ok(left
+        .iter()
+        .map(|left| left.map(|left| left >= right))
+        .collect())
+}
+
 /// Binary expression
 #[derive(Debug)]
 pub struct BinaryExpr {
@@ -126,6 +188,47 @@ macro_rules! compute_utf8_op_scalar {
     }};
 }
 
+/// Invoke a compute kernel on a boolean data array and a scalar value
+macro_rules! compute_bool_op_scalar {
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
+        use std::convert::TryInto;
+        let ll = $LEFT
+            .as_any()
+            .downcast_ref::<$DT>()
+            .expect("compute_op failed to downcast array");
+        // generate the scalar function name, such as lt_scalar, from the $OP 
parameter
+        // (which could have a value of lt) and the suffix _scalar
+        Ok(Arc::new(paste::expr! {[<$OP _bool_scalar>]}(
+            &ll,
+            $RIGHT.try_into()?,
+        )?))
+    }};
+}
+
+/// Invoke a bool compute kernel on array(s)
+macro_rules! compute_bool_op {
+    // invoke binary operator
+    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
+        let ll = $LEFT
+            .as_any()
+            .downcast_ref::<$DT>()
+            .expect("compute_op failed to downcast left side array");
+        let rr = $RIGHT
+            .as_any()
+            .downcast_ref::<$DT>()
+            .expect("compute_op failed to downcast right side array");
+        Ok(Arc::new(paste::expr! {[<$OP _bool>]}(&ll, &rr)?))
+    }};
+    // invoke unary operator
+    ($OPERAND:expr, $OP:ident, $DT:ident) => {{
+        let operand = $OPERAND
+            .as_any()
+            .downcast_ref::<$DT>()
+            .expect("compute_op failed to downcast operant array");
+        Ok(Arc::new(paste::expr! {[<$OP _bool>]}(&operand)?))
+    }};
+}
+
 /// Invoke a compute kernel on a data array and a scalar value
 macro_rules! compute_op_scalar {
     ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
@@ -276,6 +379,7 @@ macro_rules! binary_array_op_scalar {
             DataType::Date64 => {
                 compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array)
             }
+            DataType::Boolean => compute_bool_op_scalar!($LEFT, $RIGHT, $OP, 
BooleanArray),
             other => Err(DataFusionError::Internal(format!(
                 "Data type {:?} not supported for scalar operation '{}' on dyn 
array",
                 other, stringify!($OP)
@@ -320,6 +424,7 @@ macro_rules! binary_array_op {
             DataType::Date64 => {
                 compute_op!($LEFT, $RIGHT, $OP, Date64Array)
             }
+            DataType::Boolean => compute_bool_op!($LEFT, $RIGHT, $OP, 
BooleanArray),
             other => Err(DataFusionError::Internal(format!(
                 "Data type {:?} not supported for binary operation '{}' on dyn 
arrays",
                 other, stringify!($OP)
@@ -822,7 +927,7 @@ mod tests {
 
     use super::*;
     use crate::error::Result;
-    use crate::physical_plan::expressions::col;
+    use crate::physical_plan::expressions::{col, lit};
 
     // Create a binary expression without coercion. Used here when we do not 
want to coerce the expressions
     // to valid types. Usage can result in an execution (after plan) error.
@@ -1372,6 +1477,42 @@ mod tests {
         Ok(())
     }
 
+    // Test `scalar <op> arr` produces expected
+    fn apply_logic_op_scalar_arr(
+        schema: &SchemaRef,
+        scalar: bool,
+        arr: &ArrayRef,
+        op: Operator,
+        expected: &BooleanArray,
+    ) -> Result<()> {
+        let scalar = lit(scalar.into());
+
+        let arithmetic_op = binary_simple(scalar, op, col("a", schema)?);
+        let batch = RecordBatch::try_new(Arc::clone(schema), 
vec![Arc::clone(arr)])?;
+        let result = 
arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
+        assert_eq!(result.as_ref(), expected);
+
+        Ok(())
+    }
+
+    // Test `arr <op> scalar` produces expected
+    fn apply_logic_op_arr_scalar(
+        schema: &SchemaRef,
+        arr: &ArrayRef,
+        scalar: bool,
+        op: Operator,
+        expected: &BooleanArray,
+    ) -> Result<()> {
+        let scalar = lit(scalar.into());
+
+        let arithmetic_op = binary_simple(col("a", schema)?, op, scalar);
+        let batch = RecordBatch::try_new(Arc::clone(schema), 
vec![Arc::clone(arr)])?;
+        let result = 
arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
+        assert_eq!(result.as_ref(), expected);
+
+        Ok(())
+    }
+
     #[test]
     fn and_with_nulls_op() -> Result<()> {
         let schema = Schema::new(vec![
@@ -1462,6 +1603,293 @@ mod tests {
         Ok(())
     }
 
+    /// Returns (schema, a: BooleanArray, b: BooleanArray) with all possible 
inputs
+    ///
+    /// a: [true, true, true,  NULL, NULL, NULL,  false, false, false]
+    /// b: [true, NULL, false, true, NULL, false, true,  NULL,  false]
+    fn bool_test_arrays() -> (SchemaRef, BooleanArray, BooleanArray) {
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Boolean, false),
+            Field::new("b", DataType::Boolean, false),
+        ]);
+        let a = [
+            Some(true),
+            Some(true),
+            Some(true),
+            None,
+            None,
+            None,
+            Some(false),
+            Some(false),
+            Some(false),
+        ]
+        .iter()
+        .collect();
+        let b = [
+            Some(true),
+            None,
+            Some(false),
+            Some(true),
+            None,
+            Some(false),
+            Some(true),
+            None,
+            Some(false),
+        ]
+        .iter()
+        .collect();
+        (Arc::new(schema), a, b)
+    }
+
+    /// Returns (schema, BooleanArray) with [true, NULL, false]
+    fn scalar_bool_test_array() -> (SchemaRef, ArrayRef) {
+        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, 
false)]);
+        let a: BooleanArray = vec![Some(true), None, 
Some(false)].iter().collect();
+        (Arc::new(schema), Arc::new(a))
+    }
+
+    #[test]
+    fn eq_op_bool() {
+        let (schema, a, b) = bool_test_arrays();
+        let expected = vec![
+            Some(true),
+            None,
+            Some(false),
+            None,
+            None,
+            None,
+            Some(false),
+            None,
+            Some(true),
+        ]
+        .iter()
+        .collect();
+        apply_logic_op(schema, a, b, Operator::Eq, expected).unwrap();
+    }
+
+    #[test]
+    fn eq_op_bool_scalar() {
+        let (schema, a) = scalar_bool_test_array();
+        let expected = [Some(true), None, Some(false)].iter().collect();
+        apply_logic_op_scalar_arr(&schema, true, &a, Operator::Eq, 
&expected).unwrap();
+        apply_logic_op_arr_scalar(&schema, &a, true, Operator::Eq, 
&expected).unwrap();
+
+        let expected = [Some(false), None, Some(true)].iter().collect();
+        apply_logic_op_scalar_arr(&schema, false, &a, Operator::Eq, 
&expected).unwrap();
+        apply_logic_op_arr_scalar(&schema, &a, false, Operator::Eq, 
&expected).unwrap();
+    }
+
+    #[test]
+    fn neq_op_bool() {
+        let (schema, a, b) = bool_test_arrays();
+        let expected = [
+            Some(false),
+            None,
+            Some(true),
+            None,
+            None,
+            None,
+            Some(true),
+            None,
+            Some(false),
+        ]
+        .iter()
+        .collect();
+        apply_logic_op(schema, a, b, Operator::NotEq, expected).unwrap();
+    }
+
+    #[test]
+    fn neq_op_bool_scalar() {
+        let (schema, a) = scalar_bool_test_array();
+        let expected = [Some(false), None, Some(true)].iter().collect();
+        apply_logic_op_scalar_arr(&schema, true, &a, Operator::NotEq, 
&expected).unwrap();
+        apply_logic_op_arr_scalar(&schema, &a, true, Operator::NotEq, 
&expected).unwrap();
+
+        let expected = [Some(true), None, Some(false)].iter().collect();
+        apply_logic_op_scalar_arr(&schema, false, &a, Operator::NotEq, 
&expected)
+            .unwrap();
+        apply_logic_op_arr_scalar(&schema, &a, false, Operator::NotEq, 
&expected)
+            .unwrap();
+    }
+
+    #[test]
+    fn lt_op_bool() {
+        let (schema, a, b) = bool_test_arrays();
+        let expected = [
+            Some(false),
+            None,
+            Some(false),
+            None,
+            None,
+            None,
+            Some(true),
+            None,
+            Some(false),
+        ]
+        .iter()
+        .collect();
+        apply_logic_op(schema, a, b, Operator::Lt, expected).unwrap();
+    }
+
+    #[test]
+    fn lt_op_bool_scalar() {
+        let (schema, a) = scalar_bool_test_array();
+        let expected = [Some(false), None, Some(false)].iter().collect();
+        apply_logic_op_scalar_arr(&schema, true, &a, Operator::Lt, 
&expected).unwrap();
+
+        let expected = [Some(false), None, Some(true)].iter().collect();
+        apply_logic_op_arr_scalar(&schema, &a, true, Operator::Lt, 
&expected).unwrap();
+
+        let expected = [Some(true), None, Some(false)].iter().collect();
+        apply_logic_op_scalar_arr(&schema, false, &a, Operator::Lt, 
&expected).unwrap();
+
+        let expected = [Some(false), None, Some(false)].iter().collect();
+        apply_logic_op_arr_scalar(&schema, &a, false, Operator::Lt, 
&expected).unwrap();
+    }
+
+    #[test]
+    fn lt_eq_op_bool() {
+        let (schema, a, b) = bool_test_arrays();
+        let expected = [
+            Some(true),
+            None,
+            Some(false),
+            None,
+            None,
+            None,
+            Some(true),
+            None,
+            Some(true),
+        ]
+        .iter()
+        .collect();
+        apply_logic_op(schema, a, b, Operator::LtEq, expected).unwrap();
+    }
+
+    #[test]
+    fn lt_eq_op_bool_scalar() {
+        let (schema, a) = scalar_bool_test_array();
+        let expected = [Some(true), None, Some(false)].iter().collect();
+        apply_logic_op_scalar_arr(&schema, true, &a, Operator::LtEq, 
&expected).unwrap();
+
+        let expected = [Some(true), None, Some(true)].iter().collect();
+        apply_logic_op_arr_scalar(&schema, &a, true, Operator::LtEq, 
&expected).unwrap();
+
+        let expected = [Some(true), None, Some(true)].iter().collect();
+        apply_logic_op_scalar_arr(&schema, false, &a, Operator::LtEq, 
&expected).unwrap();
+
+        let expected = [Some(false), None, Some(true)].iter().collect();
+        apply_logic_op_arr_scalar(&schema, &a, false, Operator::LtEq, 
&expected).unwrap();
+    }
+
+    #[test]
+    fn gt_op_bool() {
+        let (schema, a, b) = bool_test_arrays();
+        let expected = [
+            Some(false),
+            None,
+            Some(true),
+            None,
+            None,
+            None,
+            Some(false),
+            None,
+            Some(false),
+        ]
+        .iter()
+        .collect();
+        apply_logic_op(schema, a, b, Operator::Gt, expected).unwrap();
+    }
+
+    #[test]
+    fn gt_op_bool_scalar() {
+        let (schema, a) = scalar_bool_test_array();
+        let expected = [Some(false), None, Some(true)].iter().collect();
+        apply_logic_op_scalar_arr(&schema, true, &a, Operator::Gt, 
&expected).unwrap();
+
+        let expected = [Some(false), None, Some(false)].iter().collect();
+        apply_logic_op_arr_scalar(&schema, &a, true, Operator::Gt, 
&expected).unwrap();
+
+        let expected = [Some(false), None, Some(false)].iter().collect();
+        apply_logic_op_scalar_arr(&schema, false, &a, Operator::Gt, 
&expected).unwrap();
+
+        let expected = [Some(true), None, Some(false)].iter().collect();
+        apply_logic_op_arr_scalar(&schema, &a, false, Operator::Gt, 
&expected).unwrap();
+    }
+
+    #[test]
+    fn gt_eq_op_bool() {
+        let (schema, a, b) = bool_test_arrays();
+        let expected = [
+            Some(true),
+            None,
+            Some(true),
+            None,
+            None,
+            None,
+            Some(false),
+            None,
+            Some(true),
+        ]
+        .iter()
+        .collect();
+        apply_logic_op(schema, a, b, Operator::GtEq, expected).unwrap();
+    }
+
+    #[test]
+    fn gt_eq_op_bool_scalar() {
+        let (schema, a) = scalar_bool_test_array();
+        let expected = [Some(true), None, Some(true)].iter().collect();
+        apply_logic_op_scalar_arr(&schema, true, &a, Operator::GtEq, 
&expected).unwrap();
+
+        let expected = [Some(true), None, Some(false)].iter().collect();
+        apply_logic_op_arr_scalar(&schema, &a, true, Operator::GtEq, 
&expected).unwrap();
+
+        let expected = [Some(false), None, Some(true)].iter().collect();
+        apply_logic_op_scalar_arr(&schema, false, &a, Operator::GtEq, 
&expected).unwrap();
+
+        let expected = [Some(true), None, Some(true)].iter().collect();
+        apply_logic_op_arr_scalar(&schema, &a, false, Operator::GtEq, 
&expected).unwrap();
+    }
+
+    #[test]
+    fn is_distinct_from_op_bool() {
+        let (schema, a, b) = bool_test_arrays();
+        let expected = [
+            Some(false),
+            Some(true),
+            Some(true),
+            Some(true),
+            Some(false),
+            Some(true),
+            Some(true),
+            Some(true),
+            Some(false),
+        ]
+        .iter()
+        .collect();
+        apply_logic_op(schema, a, b, Operator::IsDistinctFrom, 
expected).unwrap();
+    }
+
+    #[test]
+    fn is_not_distinct_from_op_bool() {
+        let (schema, a, b) = bool_test_arrays();
+        let expected = [
+            Some(true),
+            Some(false),
+            Some(false),
+            Some(false),
+            Some(true),
+            Some(false),
+            Some(false),
+            Some(false),
+            Some(true),
+        ]
+        .iter()
+        .collect();
+        apply_logic_op(schema, a, b, Operator::IsNotDistinctFrom, 
expected).unwrap();
+    }
+
     #[test]
     fn test_coersion_error() -> Result<()> {
         let expr =
diff --git a/datafusion/src/physical_plan/expressions/nullif.rs b/datafusion/src/physical_plan/expressions/nullif.rs
index 55e7bda..1d91599 100644
--- a/datafusion/src/physical_plan/expressions/nullif.rs
+++ b/datafusion/src/physical_plan/expressions/nullif.rs
@@ -23,7 +23,9 @@ use crate::scalar::ScalarValue;
 use arrow::array::Array;
 use arrow::array::*;
 use arrow::compute::kernels::boolean::nullif;
-use arrow::compute::kernels::comparison::{eq, eq_scalar, eq_utf8, 
eq_utf8_scalar};
+use arrow::compute::kernels::comparison::{
+    eq, eq_bool, eq_bool_scalar, eq_scalar, eq_utf8, eq_utf8_scalar,
+};
 use arrow::datatypes::{DataType, TimeUnit};
 
 /// Invoke a compute kernel on a primitive array and a Boolean Array
diff --git a/datafusion/src/physical_plan/file_format/parquet.rs b/datafusion/src/physical_plan/file_format/parquet.rs
index e7980d9..52dc8e9 100644
--- a/datafusion/src/physical_plan/file_format/parquet.rs
+++ b/datafusion/src/physical_plan/file_format/parquet.rs
@@ -714,12 +714,14 @@ mod tests {
     }
 
     #[test]
-    fn row_group_predicate_builder_unsupported_type() -> Result<()> {
+    fn row_group_predicate_builder_null_expr() -> Result<()> {
         use crate::logical_plan::{col, lit};
-        // test row group predicate with unsupported statistics type (boolean)
-        // where a null array is generated for some statistics columns
-        // int > 1 and bool = true => c1_max > 1 and null
-        let expr = col("c1").gt(lit(15)).and(col("c2").eq(lit(true)));
+        // test row group predicate with an unknown (Null) expr
+        //
+        // int > 1 and bool = NULL => c1_max > 1 and null
+        let expr = col("c1")
+            .gt(lit(15))
+            .and(col("c2").eq(lit(ScalarValue::Boolean(None))));
         let schema = Arc::new(Schema::new(vec![
             Field::new("c1", DataType::Int32, false),
             Field::new("c2", DataType::Boolean, false),
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 8349669..c910e55 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -1052,6 +1052,115 @@ async fn csv_query_having_without_group_by() -> Result<()> {
 }
 
 #[tokio::test]
+async fn csv_query_boolean_eq_neq() {
+    let mut ctx = ExecutionContext::new();
+    register_boolean(&mut ctx).await.unwrap();
+    // verify the plumbing is all hooked up for eq and neq
+    let sql = "SELECT a, b, a = b as eq, b = true as eq_scalar, a != b as neq, 
a != true as neq_scalar FROM t1";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+
+    let expected = vec![
+        "+-------+-------+-------+-----------+-------+------------+",
+        "| a     | b     | eq    | eq_scalar | neq   | neq_scalar |",
+        "+-------+-------+-------+-----------+-------+------------+",
+        "| true  | true  | true  | true      | false | false      |",
+        "| true  |       |       |           |       | false      |",
+        "| true  | false | false | false     | true  | false      |",
+        "|       | true  |       | true      |       |            |",
+        "|       |       |       |           |       |            |",
+        "|       | false |       | false     |       |            |",
+        "| false | true  | false | true      | true  | true       |",
+        "| false |       |       |           |       | true       |",
+        "| false | false | true  | false     | false | true       |",
+        "+-------+-------+-------+-----------+-------+------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+}
+
+#[tokio::test]
+async fn csv_query_boolean_lt_lt_eq() {
+    let mut ctx = ExecutionContext::new();
+    register_boolean(&mut ctx).await.unwrap();
+    // verify the plumbing is all hooked up for < and <=
+    let sql = "SELECT a, b, a < b as lt, b = true as lt_scalar, a <= b as 
lt_eq, a <= true as lt_eq_scalar FROM t1";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+
+    let expected = vec![
+        "+-------+-------+-------+-----------+-------+--------------+",
+        "| a     | b     | lt    | lt_scalar | lt_eq | lt_eq_scalar |",
+        "+-------+-------+-------+-----------+-------+--------------+",
+        "| true  | true  | false | true      | true  | true         |",
+        "| true  |       |       |           |       | true         |",
+        "| true  | false | false | false     | false | true         |",
+        "|       | true  |       | true      |       |              |",
+        "|       |       |       |           |       |              |",
+        "|       | false |       | false     |       |              |",
+        "| false | true  | true  | true      | true  | true         |",
+        "| false |       |       |           |       | true         |",
+        "| false | false | false | false     | true  | true         |",
+        "+-------+-------+-------+-----------+-------+--------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+}
+
+#[tokio::test]
+async fn csv_query_boolean_gt_gt_eq() {
+    let mut ctx = ExecutionContext::new();
+    register_boolean(&mut ctx).await.unwrap();
+    // verify the plumbing is all hooked up for > and >=
+    let sql = "SELECT a, b, a > b as gt, b = true as gt_scalar, a >= b as 
gt_eq, a >= true as gt_eq_scalar FROM t1";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+
+    let expected = vec![
+        "+-------+-------+-------+-----------+-------+--------------+",
+        "| a     | b     | gt    | gt_scalar | gt_eq | gt_eq_scalar |",
+        "+-------+-------+-------+-----------+-------+--------------+",
+        "| true  | true  | false | true      | true  | true         |",
+        "| true  |       |       |           |       | true         |",
+        "| true  | false | true  | false     | true  | true         |",
+        "|       | true  |       | true      |       |              |",
+        "|       |       |       |           |       |              |",
+        "|       | false |       | false     |       |              |",
+        "| false | true  | false | true      | false | false        |",
+        "| false |       |       |           |       | false        |",
+        "| false | false | false | false     | true  | false        |",
+        "+-------+-------+-------+-----------+-------+--------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+}
+
+#[tokio::test]
+async fn csv_query_boolean_distinct_from() {
+    let mut ctx = ExecutionContext::new();
+    register_boolean(&mut ctx).await.unwrap();
+    // verify the plumbing is all hooked up for is distinct from and is not 
distinct from
+    let sql = "SELECT a, b, \
+               a is distinct from b as df, \
+               b is distinct from true as df_scalar, \
+               a is not distinct from b as ndf, \
+               a is not distinct from true as ndf_scalar \
+               FROM t1";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+
+    let expected = vec![
+        "+-------+-------+-------+-----------+-------+------------+",
+        "| a     | b     | df    | df_scalar | ndf   | ndf_scalar |",
+        "+-------+-------+-------+-----------+-------+------------+",
+        "| true  | true  | false | false     | true  | true       |",
+        "| true  |       | true  | true      | false | true       |",
+        "| true  | false | true  | true      | false | true       |",
+        "|       | true  | true  | false     | false | false      |",
+        "|       |       | false | true      | true  | false      |",
+        "|       | false | true  | true      | false | false      |",
+        "| false | true  | true  | false     | false | false      |",
+        "| false |       | true  | true      | false | false      |",
+        "| false | false | false | true      | true  | false      |",
+        "+-------+-------+-------+-----------+-------+------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+}
+
+#[tokio::test]
 async fn csv_query_avg_sqrt() -> Result<()> {
     let mut ctx = create_ctx()?;
     register_aggregate_csv(&mut ctx).await?;
@@ -3458,6 +3567,42 @@ async fn register_aggregate_csv_by_sql(ctx: &mut ExecutionContext) {
     );
 }
 
+/// Create table "t1" with two boolean columns "a" and "b"
+async fn register_boolean(ctx: &mut ExecutionContext) -> Result<()> {
+    let a: BooleanArray = [
+        Some(true),
+        Some(true),
+        Some(true),
+        None,
+        None,
+        None,
+        Some(false),
+        Some(false),
+        Some(false),
+    ]
+    .iter()
+    .collect();
+    let b: BooleanArray = [
+        Some(true),
+        None,
+        Some(false),
+        Some(true),
+        None,
+        Some(false),
+        Some(true),
+        None,
+        Some(false),
+    ]
+    .iter()
+    .collect();
+
+    let data =
+        RecordBatch::try_from_iter([("a", Arc::new(a) as _), ("b", Arc::new(b) 
as _)])?;
+    let table = MemTable::try_new(data.schema(), vec![vec![data]])?;
+    ctx.register_table("t1", Arc::new(table))?;
+    Ok(())
+}
+
 async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
     let testdata = datafusion::test_util::arrow_test_data();
     let schema = test_util::aggr_test_schema();
@@ -4539,6 +4684,8 @@ macro_rules! test_expression {
 async fn test_boolean_expressions() -> Result<()> {
     test_expression!("true", "true");
     test_expression!("false", "false");
+    test_expression!("false = false", "true");
+    test_expression!("true = false", "false");
     Ok(())
 }
 

Reply via email to