This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d7b221213 feat: add optimization rules for bitwise operations (#5423)
d7b221213 is described below

commit d7b221213c767e36c99bc17968eae1323e21b7c4
Author: Igor Izvekov <[email protected]>
AuthorDate: Tue Mar 7 23:07:02 2023 +0300

    feat: add optimization rules for bitwise operations (#5423)
    
    * feat: add optimization rules for bitwise operations
    
    * fix: cargo fmt
    
    * fix: comments, determining data type of an int literal, additional tests
    
    * fix: add "get_data_type" to impl SimplifyInfo for MyInfo
    
    * fix: add behavior of the function "get_data_type"
    
    * add information about "get_data_type" to the docs
    
    * fix: doc tests for "expr_simplifier.rs"
    
    * fix: delete_xor_in_complex_expr
    
    * fix: some bitwise optimization rules, that can work even if an expr is 
nullable
    
    * fix: prefer ScalarValue methods over "new_int_by_expr_data_type"
    
    * fix: ScalarValue::new_negative_one
    
    * fix: get_data_type
    
    * fix: cargo clippy and fmt style
---
 datafusion/common/src/scalar.rs                    |  20 +-
 datafusion/core/tests/simplification.rs            |   7 +
 datafusion/expr/src/expr.rs                        |  25 +
 datafusion/expr/src/expr_fn.rs                     |  45 ++
 .../optimizer/src/simplify_expressions/context.rs  |  18 +
 .../src/simplify_expressions/expr_simplifier.rs    | 827 ++++++++++++++++++++-
 .../optimizer/src/simplify_expressions/utils.rs    |  62 +-
 7 files changed, 997 insertions(+), 7 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 212328121..3f2f4bb39 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -1019,7 +1019,7 @@ impl ScalarValue {
         Self::List(scalars, Box::new(Field::new("item", child_type, true)))
     }
 
-    // Create a zero value in the given type.
+    /// Create a zero value in the given type.
     pub fn new_zero(datatype: &DataType) -> Result<ScalarValue> {
         assert!(datatype.is_primitive());
         Ok(match datatype {
@@ -1042,6 +1042,24 @@ impl ScalarValue {
         })
     }
 
+    /// Create a negative one value in the given type.
+    pub fn new_negative_one(datatype: &DataType) -> Result<ScalarValue> {
+        assert!(datatype.is_primitive());
+        Ok(match datatype {
+            DataType::Int8 | DataType::UInt8 => ScalarValue::Int8(Some(-1)),
+            DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)),
+            DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)),
+            DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)),
+            DataType::Float32 => ScalarValue::Float32(Some(-1.0)),
+            DataType::Float64 => ScalarValue::Float64(Some(-1.0)),
+            _ => {
+                return Err(DataFusionError::NotImplemented(format!(
+                    "Can't create a negative one scalar from data_type 
\"{datatype:?}\""
+                )));
+            }
+        })
+    }
+
     /// Getter for the `DataType` of the value
     pub fn get_datatype(&self) -> DataType {
         match self {
diff --git a/datafusion/core/tests/simplification.rs 
b/datafusion/core/tests/simplification.rs
index 6e74fc0d9..f6b944b50 100644
--- a/datafusion/core/tests/simplification.rs
+++ b/datafusion/core/tests/simplification.rs
@@ -49,6 +49,13 @@ impl SimplifyInfo for MyInfo {
     fn execution_props(&self) -> &ExecutionProps {
         &self.execution_props
     }
+
+    fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
+        match expr.get_type(&self.schema) {
+            Ok(expr_data_type) => Ok(expr_data_type),
+            Err(e) => Err(e),
+        }
+    }
 }
 
 impl From<DFSchema> for MyInfo {
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 8b6c39043..1530513e9 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -639,6 +639,31 @@ impl Expr {
         binary_expr(self, Operator::Or, other)
     }
 
+    /// Return `self & other`
+    pub fn bitwise_and(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::BitwiseAnd, other)
+    }
+
+    /// Return `self | other`
+    pub fn bitwise_or(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::BitwiseOr, other)
+    }
+
+    /// Return `self ^ other`
+    pub fn bitwise_xor(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::BitwiseXor, other)
+    }
+
+    /// Return `self >> other`
+    pub fn bitwise_shift_right(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::BitwiseShiftRight, other)
+    }
+
+    /// Return `self << other`
+    pub fn bitwise_shift_left(self, other: Expr) -> Expr {
+        binary_expr(self, Operator::BitwiseShiftLeft, other)
+    }
+
     /// Return `!self`
     #[allow(clippy::should_implement_trait)]
     pub fn not(self) -> Expr {
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 325fac57c..6465ca80b 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -112,6 +112,51 @@ pub fn count(expr: Expr) -> Expr {
     ))
 }
 
+/// Return a new expression with bitwise AND
+pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
+    Expr::BinaryExpr(BinaryExpr::new(
+        Box::new(left),
+        Operator::BitwiseAnd,
+        Box::new(right),
+    ))
+}
+
+/// Return a new expression with bitwise OR
+pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
+    Expr::BinaryExpr(BinaryExpr::new(
+        Box::new(left),
+        Operator::BitwiseOr,
+        Box::new(right),
+    ))
+}
+
+/// Return a new expression with bitwise XOR
+pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
+    Expr::BinaryExpr(BinaryExpr::new(
+        Box::new(left),
+        Operator::BitwiseXor,
+        Box::new(right),
+    ))
+}
+
+/// Return a new expression with bitwise SHIFT RIGHT
+pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
+    Expr::BinaryExpr(BinaryExpr::new(
+        Box::new(left),
+        Operator::BitwiseShiftRight,
+        Box::new(right),
+    ))
+}
+
+/// Return a new expression with bitwise SHIFT LEFT
+pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
+    Expr::BinaryExpr(BinaryExpr::new(
+        Box::new(left),
+        Operator::BitwiseShiftLeft,
+        Box::new(right),
+    ))
+}
+
 /// Create an expression to represent the count(distinct) aggregate function
 pub fn count_distinct(expr: Expr) -> Expr {
     Expr::AggregateFunction(AggregateFunction::new(
diff --git a/datafusion/optimizer/src/simplify_expressions/context.rs 
b/datafusion/optimizer/src/simplify_expressions/context.rs
index 379a803f4..b6e3f07b2 100644
--- a/datafusion/optimizer/src/simplify_expressions/context.rs
+++ b/datafusion/optimizer/src/simplify_expressions/context.rs
@@ -38,6 +38,9 @@ pub trait SimplifyInfo {
 
     /// Returns details needed for partial expression evaluation
     fn execution_props(&self) -> &ExecutionProps;
+
+    /// Returns the data type of `expr`, used to choose the integer type of literals produced by simplification
+    fn get_data_type(&self, expr: &Expr) -> Result<DataType>;
 }
 
 /// Provides simplification information based on DFSchema and
@@ -123,6 +126,21 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> {
             })
     }
 
+    /// Returns the data type of `expr`, used to choose the integer type of literals produced by simplification
+    fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
+        if self.schemas.len() == 1 {
+            match expr.get_type(&self.schemas[0]) {
+                Ok(expr_data_type) => Ok(expr_data_type),
+                Err(e) => Err(e),
+            }
+        } else {
+            Err(DataFusionError::Internal(
+                "The expr has more than one schema, could not determine data 
type"
+                    .to_string(),
+            ))
+        }
+    }
+
     fn execution_props(&self) -> &ExecutionProps {
         self.props
     }
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 5421c43da..220d532b0 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -70,6 +70,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
     /// `b > 2`
     ///
     /// ```
+    /// use arrow::datatypes::DataType;
     /// use datafusion_expr::{col, lit, Expr};
     /// use datafusion_common::Result;
     /// use datafusion_physical_expr::execution_props::ExecutionProps;
@@ -92,6 +93,9 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
     ///   fn execution_props(&self) -> &ExecutionProps {
     ///     &self.execution_props
     ///   }
+    ///   fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
+    ///     Ok(DataType::Int32)
+    ///   }
     /// }
     ///
     /// // Create the simplifier
@@ -337,7 +341,8 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, 
S> {
     /// rewrite the expression simplifying any constant expressions
     fn mutate(&mut self, expr: Expr) -> Result<Expr> {
         use datafusion_expr::Operator::{
-            And, Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, 
RegexMatch,
+            And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, 
BitwiseXor,
+            Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch,
             RegexNotIMatch, RegexNotMatch,
         };
 
@@ -700,6 +705,298 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, 
S> {
                 return 
Err(DataFusionError::ArrowError(ArrowError::DivideByZero));
             }
 
+            //
+            // Rules for BitwiseAnd
+            //
+
+            // A & null -> null
+            Expr::BinaryExpr(BinaryExpr {
+                left: _,
+                op: BitwiseAnd,
+                right,
+            }) if is_null(&right) => *right,
+
+            // null & A -> null
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseAnd,
+                right: _,
+            }) if is_null(&left) => *left,
+
+            // A & 0 -> 0 (if A not nullable)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseAnd,
+                right,
+            }) if !info.nullable(&left)? && is_zero(&right) => *right,
+
+            // 0 & A -> 0 (if A not nullable)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseAnd,
+                right,
+            }) if !info.nullable(&right)? && is_zero(&left) => *left,
+
+            // !A & A -> 0 (if A not nullable)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseAnd,
+                right,
+            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
+                
Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?)
+            }
+
+            // A & !A -> 0 (if A not nullable)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseAnd,
+                right,
+            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
+                
Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?)
+            }
+
+            // (..A..) & A --> (..A..)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseAnd,
+                right,
+            }) if expr_contains(&left, &right, BitwiseAnd) => *left,
+
+            // A & (..A..) --> (..A..)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseAnd,
+                right,
+            }) if expr_contains(&right, &left, BitwiseAnd) => *right,
+
+            // A & (A | B) --> A (if B not null)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseAnd,
+                right,
+            }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, 
&left) => {
+                *left
+            }
+
+            // (A | B) & A --> A (if B not null)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseAnd,
+                right,
+            }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, 
&right) => {
+                *right
+            }
+
+            //
+            // Rules for BitwiseOr
+            //
+
+            // A | null -> null
+            Expr::BinaryExpr(BinaryExpr {
+                left: _,
+                op: BitwiseOr,
+                right,
+            }) if is_null(&right) => *right,
+
+            // null | A -> null
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseOr,
+                right: _,
+            }) if is_null(&left) => *left,
+
+            // A | 0 -> A (even if A is null)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseOr,
+                right,
+            }) if is_zero(&right) => *left,
+
+            // 0 | A -> A (even if A is null)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseOr,
+                right,
+            }) if is_zero(&left) => *right,
+
+            // !A | A -> -1 (if A not nullable)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseOr,
+                right,
+            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
+                
Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?)
+            }
+
+            // A | !A -> -1 (if A not nullable)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseOr,
+                right,
+            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
+                
Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?)
+            }
+
+            // (..A..) | A --> (..A..)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseOr,
+                right,
+            }) if expr_contains(&left, &right, BitwiseOr) => *left,
+
+            // A | (..A..) --> (..A..)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseOr,
+                right,
+            }) if expr_contains(&right, &left, BitwiseOr) => *right,
+
+            // A | (A & B) --> A (if B not null)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseOr,
+                right,
+            }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, 
&left) => {
+                *left
+            }
+
+            // (A & B) | A --> A (if B not null)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseOr,
+                right,
+            }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, 
&right) => {
+                *right
+            }
+
+            //
+            // Rules for BitwiseXor
+            //
+
+            // A ^ null -> null
+            Expr::BinaryExpr(BinaryExpr {
+                left: _,
+                op: BitwiseXor,
+                right,
+            }) if is_null(&right) => *right,
+
+            // null ^ A -> null
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseXor,
+                right: _,
+            }) if is_null(&left) => *left,
+
+            // A ^ 0 -> A (if A not nullable)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseXor,
+                right,
+            }) if !info.nullable(&left)? && is_zero(&right) => *left,
+
+            // 0 ^ A -> A (if A not nullable)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseXor,
+                right,
+            }) if !info.nullable(&right)? && is_zero(&left) => *right,
+
+            // !A ^ A -> -1 (if A not nullable)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseXor,
+                right,
+            }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
+                
Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?)
+            }
+
+            // A ^ !A -> -1 (if A not nullable)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseXor,
+                right,
+            }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
+                
Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?)
+            }
+
+            // (..A..) ^ A --> (the expression without A, if number of A is 
odd, otherwise one A)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseXor,
+                right,
+            }) if expr_contains(&left, &right, BitwiseXor) => {
+                let expr = delete_xor_in_complex_expr(&left, &right, false);
+                if expr == *right {
+                    
Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?)
+                } else {
+                    expr
+                }
+            }
+
+            // A ^ (..A..) --> (the expression without A, if number of A is 
odd, otherwise one A)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseXor,
+                right,
+            }) if expr_contains(&right, &left, BitwiseXor) => {
+                let expr = delete_xor_in_complex_expr(&right, &left, true);
+                if expr == *left {
+                    
Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?)
+                } else {
+                    expr
+                }
+            }
+
+            //
+            // Rules for BitwiseShiftRight
+            //
+
+            // A >> null -> null
+            Expr::BinaryExpr(BinaryExpr {
+                left: _,
+                op: BitwiseShiftRight,
+                right,
+            }) if is_null(&right) => *right,
+
+            // null >> A -> null
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseShiftRight,
+                right: _,
+            }) if is_null(&left) => *left,
+
+            // A >> 0 -> A (even if A is null)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseShiftRight,
+                right,
+            }) if is_zero(&right) => *left,
+
+            //
+            // Rules for BitwiseShiftLeft
+            //
+
+            // A << null -> null
+            Expr::BinaryExpr(BinaryExpr {
+                left: _,
+                op: BitwiseShiftLeft,
+                right,
+            }) if is_null(&right) => *right,
+
+            // null << A -> null
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseShiftLeft,
+                right: _,
+            }) if is_null(&left) => *left,
+
+            // A << 0 -> A (even if A is null)
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: BitwiseShiftLeft,
+                right,
+            }) if is_zero(&right) => *left,
+
             //
             // Rules for Not
             //
@@ -1346,6 +1643,522 @@ mod tests {
         assert_eq!(simplify(expr), expected);
     }
 
+    #[test]
+    fn test_simplify_bitwise_xor_by_null() {
+        let null = Expr::Literal(ScalarValue::Null);
+        // A ^ null --> null
+        {
+            let expr = binary_expr(col("c2"), Operator::BitwiseXor, 
null.clone());
+            assert_eq!(simplify(expr), null);
+        }
+        // null ^ A --> null
+        {
+            let expr = binary_expr(null.clone(), Operator::BitwiseXor, 
col("c2"));
+            assert_eq!(simplify(expr), null);
+        }
+    }
+
+    #[test]
+    fn test_simplify_bitwise_shift_right_by_null() {
+        let null = Expr::Literal(ScalarValue::Null);
+        // A >> null --> null
+        {
+            let expr = binary_expr(col("c2"), Operator::BitwiseShiftRight, 
null.clone());
+            assert_eq!(simplify(expr), null);
+        }
+        // null >> A --> null
+        {
+            let expr = binary_expr(null.clone(), Operator::BitwiseShiftRight, 
col("c2"));
+            assert_eq!(simplify(expr), null);
+        }
+    }
+
+    #[test]
+    fn test_simplify_bitwise_shift_left_by_null() {
+        let null = Expr::Literal(ScalarValue::Null);
+        // A << null --> null
+        {
+            let expr = binary_expr(col("c2"), Operator::BitwiseShiftLeft, 
null.clone());
+            assert_eq!(simplify(expr), null);
+        }
+        // null << A --> null
+        {
+            let expr = binary_expr(null.clone(), Operator::BitwiseShiftLeft, 
col("c2"));
+            assert_eq!(simplify(expr), null);
+        }
+    }
+
+    #[test]
+    fn test_simplify_bitwise_and_by_zero() {
+        // A & 0 --> 0
+        {
+            let expr = binary_expr(col("c2_non_null"), Operator::BitwiseAnd, 
lit(0));
+            assert_eq!(simplify(expr), lit(0));
+        }
+        // 0 & A --> 0
+        {
+            let expr = binary_expr(lit(0), Operator::BitwiseAnd, 
col("c2_non_null"));
+            assert_eq!(simplify(expr), lit(0));
+        }
+    }
+
+    #[test]
+    fn test_simplify_bitwise_or_by_zero() {
+        // A | 0 --> A
+        {
+            let expr = binary_expr(col("c2_non_null"), Operator::BitwiseOr, 
lit(0));
+            assert_eq!(simplify(expr), col("c2_non_null"));
+        }
+        // 0 | A --> A
+        {
+            let expr = binary_expr(lit(0), Operator::BitwiseOr, 
col("c2_non_null"));
+            assert_eq!(simplify(expr), col("c2_non_null"));
+        }
+    }
+
+    #[test]
+    fn test_simplify_bitwise_xor_by_zero() {
+        // A ^ 0 --> A
+        {
+            let expr = binary_expr(col("c2_non_null"), Operator::BitwiseXor, 
lit(0));
+            assert_eq!(simplify(expr), col("c2_non_null"));
+        }
+        // 0 ^ A --> A
+        {
+            let expr = binary_expr(lit(0), Operator::BitwiseXor, 
col("c2_non_null"));
+            assert_eq!(simplify(expr), col("c2_non_null"));
+        }
+    }
+
+    #[test]
+    fn test_simplify_bitwise_bitwise_shift_right_by_zero() {
+        // A >> 0 --> A
+        {
+            let expr =
+                binary_expr(col("c2_non_null"), Operator::BitwiseShiftRight, 
lit(0));
+            assert_eq!(simplify(expr), col("c2_non_null"));
+        }
+    }
+
+    #[test]
+    fn test_simplify_bitwise_bitwise_shift_left_by_zero() {
+        // A << 0 --> A
+        {
+            let expr =
+                binary_expr(col("c2_non_null"), Operator::BitwiseShiftLeft, 
lit(0));
+            assert_eq!(simplify(expr), col("c2_non_null"));
+        }
+    }
+
+    #[test]
+    fn test_simplify_bitwise_and_by_null() {
+        let null = Expr::Literal(ScalarValue::Null);
+        // A & null --> null
+        {
+            let expr = binary_expr(col("c2"), Operator::BitwiseAnd, 
null.clone());
+            assert_eq!(simplify(expr), null);
+        }
+        // null & A --> null
+        {
+            let expr = binary_expr(null.clone(), Operator::BitwiseAnd, 
col("c2"));
+            assert_eq!(simplify(expr), null);
+        }
+    }
+
+    #[test]
+    fn test_simplify_composed_bitwise_and() {
+        // ((c2 > 5) & (c1 < 6)) & (c2 > 5) --> (c2 > 5) & (c1 < 6)
+
+        let expr = binary_expr(
+            binary_expr(
+                col("c2").gt(lit(5)),
+                Operator::BitwiseAnd,
+                col("c1").lt(lit(6)),
+            ),
+            Operator::BitwiseAnd,
+            col("c2").gt(lit(5)),
+        );
+        let expected = binary_expr(
+            col("c2").gt(lit(5)),
+            Operator::BitwiseAnd,
+            col("c1").lt(lit(6)),
+        );
+
+        assert_eq!(simplify(expr), expected);
+
+        // (c2 > 5) & ((c2 > 5) & (c1 < 6)) --> (c2 > 5) & (c1 < 6)
+
+        let expr = binary_expr(
+            col("c2").gt(lit(5)),
+            Operator::BitwiseAnd,
+            binary_expr(
+                col("c2").gt(lit(5)),
+                Operator::BitwiseAnd,
+                col("c1").lt(lit(6)),
+            ),
+        );
+        let expected = binary_expr(
+            col("c2").gt(lit(5)),
+            Operator::BitwiseAnd,
+            col("c1").lt(lit(6)),
+        );
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_composed_bitwise_or() {
+        // ((c2 > 5) | (c1 < 6)) | (c2 > 5) --> (c2 > 5) | (c1 < 6)
+
+        let expr = binary_expr(
+            binary_expr(
+                col("c2").gt(lit(5)),
+                Operator::BitwiseOr,
+                col("c1").lt(lit(6)),
+            ),
+            Operator::BitwiseOr,
+            col("c2").gt(lit(5)),
+        );
+        let expected = binary_expr(
+            col("c2").gt(lit(5)),
+            Operator::BitwiseOr,
+            col("c1").lt(lit(6)),
+        );
+
+        assert_eq!(simplify(expr), expected);
+
+        // (c2 > 5) | ((c2 > 5) | (c1 < 6)) --> (c2 > 5) | (c1 < 6)
+
+        let expr = binary_expr(
+            col("c2").gt(lit(5)),
+            Operator::BitwiseOr,
+            binary_expr(
+                col("c2").gt(lit(5)),
+                Operator::BitwiseOr,
+                col("c1").lt(lit(6)),
+            ),
+        );
+        let expected = binary_expr(
+            col("c2").gt(lit(5)),
+            Operator::BitwiseOr,
+            col("c1").lt(lit(6)),
+        );
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_composed_bitwise_xor() {
+        // with an even number of the column "c2"
+        // c2 ^ ((c2 ^ (c2 | c1)) ^ (c1 & c2)) --> (c2 | c1) ^ (c1 & c2)
+
+        let expr = binary_expr(
+            col("c2"),
+            Operator::BitwiseXor,
+            binary_expr(
+                binary_expr(
+                    col("c2"),
+                    Operator::BitwiseXor,
+                    binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+                ),
+                Operator::BitwiseXor,
+                binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+            ),
+        );
+
+        let expected = binary_expr(
+            binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+            Operator::BitwiseXor,
+            binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+        );
+
+        assert_eq!(simplify(expr), expected);
+
+        // with an odd number of the column "c2"
+        // c2 ^ (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) --> c2 ^ ((c2 | c1) ^ (c1 
& c2))
+
+        let expr = binary_expr(
+            col("c2"),
+            Operator::BitwiseXor,
+            binary_expr(
+                binary_expr(
+                    col("c2"),
+                    Operator::BitwiseXor,
+                    binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+                ),
+                Operator::BitwiseXor,
+                binary_expr(
+                    binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+                    Operator::BitwiseXor,
+                    col("c2"),
+                ),
+            ),
+        );
+
+        let expected = binary_expr(
+            col("c2"),
+            Operator::BitwiseXor,
+            binary_expr(
+                binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+                Operator::BitwiseXor,
+                binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+            ),
+        );
+
+        assert_eq!(simplify(expr), expected);
+
+        // with an even number of the column "c2"
+        // ((c2 ^ (c2 | c1)) ^ (c1 & c2)) ^ c2 --> (c2 | c1) ^ (c1 & c2)
+
+        let expr = binary_expr(
+            binary_expr(
+                binary_expr(
+                    col("c2"),
+                    Operator::BitwiseXor,
+                    binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+                ),
+                Operator::BitwiseXor,
+                binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+            ),
+            Operator::BitwiseXor,
+            col("c2"),
+        );
+
+        let expected = binary_expr(
+            binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+            Operator::BitwiseXor,
+            binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+        );
+
+        assert_eq!(simplify(expr), expected);
+
+        // with an odd number of the column "c2"
+        // (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) ^ c2 --> ((c2 | c1) ^ (c1 & 
c2)) ^ c2
+
+        let expr = binary_expr(
+            binary_expr(
+                binary_expr(
+                    col("c2"),
+                    Operator::BitwiseXor,
+                    binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+                ),
+                Operator::BitwiseXor,
+                binary_expr(
+                    binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+                    Operator::BitwiseXor,
+                    col("c2"),
+                ),
+            ),
+            Operator::BitwiseXor,
+            col("c2"),
+        );
+
+        let expected = binary_expr(
+            binary_expr(
+                binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+                Operator::BitwiseXor,
+                binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+            ),
+            Operator::BitwiseXor,
+            col("c2"),
+        );
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_negated_bitwise_and() {
+        // !c4 & c4 --> 0
+        let expr = binary_expr(
+            Expr::Negative(Box::new(col("c4_non_null"))),
+            Operator::BitwiseAnd,
+            col("c4_non_null"),
+        );
+        let expected = Expr::Literal(ScalarValue::UInt32(Some(0)));
+
+        assert_eq!(simplify(expr), expected);
+        // c4 & !c4 --> 0
+        let expr = binary_expr(
+            col("c4_non_null"),
+            Operator::BitwiseAnd,
+            Expr::Negative(Box::new(col("c4_non_null"))),
+        );
+        let expected = Expr::Literal(ScalarValue::UInt32(Some(0)));
+
+        assert_eq!(simplify(expr), expected);
+
+        // !c3 & c3 --> 0
+        let expr = binary_expr(
+            Expr::Negative(Box::new(col("c3_non_null"))),
+            Operator::BitwiseAnd,
+            col("c3_non_null"),
+        );
+        let expected = Expr::Literal(ScalarValue::Int64(Some(0)));
+
+        assert_eq!(simplify(expr), expected);
+        // c3 & !c3 --> 0
+        let expr = binary_expr(
+            col("c3_non_null"),
+            Operator::BitwiseAnd,
+            Expr::Negative(Box::new(col("c3_non_null"))),
+        );
+        let expected = Expr::Literal(ScalarValue::Int64(Some(0)));
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_negated_bitwise_or() {
+        // !c4 | c4 --> -1
+        let expr = binary_expr(
+            Expr::Negative(Box::new(col("c4_non_null"))),
+            Operator::BitwiseOr,
+            col("c4_non_null"),
+        );
+        let expected = Expr::Literal(ScalarValue::Int32(Some(-1)));
+
+        assert_eq!(simplify(expr), expected);
+
+        // c4 | !c4 --> -1
+        let expr = binary_expr(
+            col("c4_non_null"),
+            Operator::BitwiseOr,
+            Expr::Negative(Box::new(col("c4_non_null"))),
+        );
+        let expected = Expr::Literal(ScalarValue::Int32(Some(-1)));
+
+        assert_eq!(simplify(expr), expected);
+
+        // !c3 | c3 --> -1
+        let expr = binary_expr(
+            Expr::Negative(Box::new(col("c3_non_null"))),
+            Operator::BitwiseOr,
+            col("c3_non_null"),
+        );
+        let expected = Expr::Literal(ScalarValue::Int64(Some(-1)));
+
+        assert_eq!(simplify(expr), expected);
+
+        // c3 | !c3 --> -1
+        let expr = binary_expr(
+            col("c3_non_null"),
+            Operator::BitwiseOr,
+            Expr::Negative(Box::new(col("c3_non_null"))),
+        );
+        let expected = Expr::Literal(ScalarValue::Int64(Some(-1)));
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_negated_bitwise_xor() {
+        // !c4 ^ c4 --> -1 (c4 is UInt32; the expected literal is Int32(-1))
+        let expr = binary_expr(
+            Expr::Negative(Box::new(col("c4_non_null"))),
+            Operator::BitwiseXor,
+            col("c4_non_null"),
+        );
+        let expected = Expr::Literal(ScalarValue::Int32(Some(-1)));
+
+        assert_eq!(simplify(expr), expected);
+
+        // c4 ^ !c4 --> -1
+        let expr = binary_expr(
+            col("c4_non_null"),
+            Operator::BitwiseXor,
+            Expr::Negative(Box::new(col("c4_non_null"))),
+        );
+        let expected = Expr::Literal(ScalarValue::Int32(Some(-1)));
+
+        assert_eq!(simplify(expr), expected);
+
+        // !c3 ^ c3 --> -1
+        let expr = binary_expr(
+            Expr::Negative(Box::new(col("c3_non_null"))),
+            Operator::BitwiseXor,
+            col("c3_non_null"),
+        );
+        let expected = Expr::Literal(ScalarValue::Int64(Some(-1)));
+
+        assert_eq!(simplify(expr), expected);
+
+        // c3 ^ !c3 --> -1
+        let expr = binary_expr(
+            col("c3_non_null"),
+            Operator::BitwiseXor,
+            Expr::Negative(Box::new(col("c3_non_null"))),
+        );
+        let expected = Expr::Literal(ScalarValue::Int64(Some(-1)));
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_bitwise_and_or() {
+        // (c2 < 3) & ((c2 < 3) | c1) -> (c2 < 3)
+        let expr = binary_expr(
+            col("c2_non_null").lt(lit(3)),
+            Operator::BitwiseAnd,
+            binary_expr(
+                col("c2_non_null").lt(lit(3)),
+                Operator::BitwiseOr,
+                col("c1_non_null"),
+            ),
+        );
+        let expected = col("c2_non_null").lt(lit(3));
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_bitwise_or_and() {
+        // (c2 < 3) | ((c2 < 3) & c1) -> (c2 < 3)
+        let expr = binary_expr(
+            col("c2_non_null").lt(lit(3)),
+            Operator::BitwiseOr,
+            binary_expr(
+                col("c2_non_null").lt(lit(3)),
+                Operator::BitwiseAnd,
+                col("c1_non_null"),
+            ),
+        );
+        let expected = col("c2_non_null").lt(lit(3));
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_simple_bitwise_and() {
+        // (c2 > 5) & (c2 > 5) -> (c2 > 5)
+        let expr = (col("c2").gt(lit(5))).bitwise_and(col("c2").gt(lit(5)));
+        let expected = col("c2").gt(lit(5));
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_simple_bitwise_or() {
+        // (c2 > 5) | (c2 > 5) -> (c2 > 5)
+        let expr = (col("c2").gt(lit(5))).bitwise_or(col("c2").gt(lit(5)));
+        let expected = col("c2").gt(lit(5));
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_simple_bitwise_xor() {
+        // c4 ^ c4 -> 0
+        let expr = (col("c4")).bitwise_xor(col("c4"));
+        let expected = Expr::Literal(ScalarValue::UInt32(Some(0)));
+
+        assert_eq!(simplify(expr), expected);
+
+        // c3 ^ c3 -> 0
+        let expr = col("c3").bitwise_xor(col("c3"));
+        let expected = Expr::Literal(ScalarValue::Int64(Some(0)));
+
+        assert_eq!(simplify(expr), expected);
+    }
+
     #[test]
     #[should_panic(
         expected = "called `Result::unwrap()` on an `Err` value: 
ArrowError(DivideByZero)"
@@ -1357,7 +2170,7 @@ mod tests {
 
     #[test]
     fn test_simplify_simple_and() {
-        // (c > 5) AND (c > 5)
+        // (c2 > 5) AND (c2 > 5) -> (c2 > 5)
         let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5)));
         let expected = col("c2").gt(lit(5));
 
@@ -1366,7 +2179,7 @@ mod tests {
 
     #[test]
     fn test_simplify_composed_and() {
-        // ((c > 5) AND (c1 < 6)) AND (c > 5)
+        // ((c2 > 5) AND (c1 < 6)) AND (c2 > 5)
         let expr = binary_expr(
             binary_expr(col("c2").gt(lit(5)), Operator::And, 
col("c1").lt(lit(6))),
             Operator::And,
@@ -1380,7 +2193,7 @@ mod tests {
 
     #[test]
     fn test_simplify_negated_and() {
-        // (c > 5) AND !(c > 5) -- > (c > 5) AND (c <= 5)
+        // (c2 > 5) AND !(c2 > 5) --> (c2 > 5) AND (c2 <= 5)
         let expr = binary_expr(
             col("c2").gt(lit(5)),
             Operator::And,
@@ -1760,8 +2573,12 @@ mod tests {
                 vec![
                     DFField::new(None, "c1", DataType::Utf8, true),
                     DFField::new(None, "c2", DataType::Boolean, true),
+                    DFField::new(None, "c3", DataType::Int64, true),
+                    DFField::new(None, "c4", DataType::UInt32, true),
                     DFField::new(None, "c1_non_null", DataType::Utf8, false),
                     DFField::new(None, "c2_non_null", DataType::Boolean, 
false),
+                    DFField::new(None, "c3_non_null", DataType::Int64, false),
+                    DFField::new(None, "c4_non_null", DataType::UInt32, false),
                 ],
                 HashMap::new(),
             )
@@ -1808,7 +2625,7 @@ mod tests {
         let schema = expr_test_schema();
         assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
 
-        // true = ture -> true
+        // true = true -> true
         assert_eq!(simplify(lit(true).eq(lit(true))), lit(true));
 
         // true = false -> false
diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs
index 10e5c0e87..352674c3a 100644
--- a/datafusion/optimizer/src/simplify_expressions/utils.rs
+++ b/datafusion/optimizer/src/simplify_expressions/utils.rs
@@ -77,6 +77,61 @@ pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
     }
 }
 
+/// Deletes every `needle` found in a chain of xor expressions, such as
+/// A ^ (A ^ (B ^ A)). One `needle` is xor-ed back if an even number was deleted.
+pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) 
-> Expr {
+    /// Recursively deletes every `needle` from a chain of xor expressions
+    fn recursive_delete_xor_in_expr(
+        expr: &Expr,
+        needle: &Expr,
+        xor_counter: &mut i32,
+    ) -> Expr {
+        match expr {
+            Expr::BinaryExpr(BinaryExpr { left, op, right })
+                if *op == Operator::BitwiseXor =>
+            {
+                let left_expr = recursive_delete_xor_in_expr(left, needle, 
xor_counter);
+                let right_expr = recursive_delete_xor_in_expr(right, needle, 
xor_counter);
+                if left_expr == *needle {
+                    *xor_counter += 1;
+                    return right_expr;
+                } else if right_expr == *needle {
+                    *xor_counter += 1;
+                    return left_expr;
+                }
+
+                Expr::BinaryExpr(BinaryExpr::new(
+                    Box::new(left_expr),
+                    *op,
+                    Box::new(right_expr),
+                ))
+            }
+            _ => expr.clone(),
+        }
+    }
+
+    let mut xor_counter: i32 = 0;
+    let result_expr = recursive_delete_xor_in_expr(expr, needle, &mut 
xor_counter);
+    if result_expr == *needle {
+        return needle.clone();
+    } else if xor_counter % 2 == 0 {
+        if is_left {
+            return Expr::BinaryExpr(BinaryExpr::new(
+                Box::new(needle.clone()),
+                Operator::BitwiseXor,
+                Box::new(result_expr),
+            ));
+        } else {
+            return Expr::BinaryExpr(BinaryExpr::new(
+                Box::new(result_expr),
+                Operator::BitwiseXor,
+                Box::new(needle.clone()),
+            ));
+        }
+    }
+    result_expr
+}
+
 pub fn is_zero(s: &Expr) -> bool {
     match s {
         Expr::Literal(ScalarValue::Int8(Some(0)))
@@ -154,11 +209,16 @@ pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {
     matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op 
== &target_op && (needle == left.as_ref() || needle == right.as_ref()))
 }
 
-/// returns true if `not_expr` is !`expr`
+/// returns true if `not_expr` is !`expr` (not)
 pub fn is_not_of(not_expr: &Expr, expr: &Expr) -> bool {
     matches!(not_expr, Expr::Not(inner) if expr == inner.as_ref())
 }
 
+/// returns true if `not_expr` is `Expr::Negative` of `expr` (treated as bitwise not)
+pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool {
+    matches!(not_expr, Expr::Negative(inner) if expr == inner.as_ref())
+}
+
 /// returns the contained boolean value in `expr` as
 /// `Expr::Literal(ScalarValue::Boolean(v))`.
 pub fn as_bool_lit(expr: Expr) -> Result<Option<bool>> {


Reply via email to