This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d7b221213 feat: add optimization rules for bitwise operations (#5423)
d7b221213 is described below
commit d7b221213c767e36c99bc17968eae1323e21b7c4
Author: Igor Izvekov <[email protected]>
AuthorDate: Tue Mar 7 23:07:02 2023 +0300
feat: add optimization rules for bitwise operations (#5423)
* feat: add optimization rules for bitwise operations
* fix: cargo fmt
* fix: comments, determining data type of an int literal, additional tests
* fix: add "get_data_type" to impl SimplifyInfo for MyInfo
* fix: add behavior of the function "get_data_type"
* add information about "get_data_type" to the docs
* fix: doc tests for "expr_simplifier.rs"
* fix: delete_xor_in_complex_expr
* fix: some bitwise optimization rules that can work even if an expr is nullable
* fix: prefer ScalarValue methods over "new_int_by_expr_data_type"
* fix: ScalarValue::new_negative_one
* fix: get_data_type
* fix: cargo clippy and fmt style
---
datafusion/common/src/scalar.rs | 20 +-
datafusion/core/tests/simplification.rs | 7 +
datafusion/expr/src/expr.rs | 25 +
datafusion/expr/src/expr_fn.rs | 45 ++
.../optimizer/src/simplify_expressions/context.rs | 18 +
.../src/simplify_expressions/expr_simplifier.rs | 827 ++++++++++++++++++++-
.../optimizer/src/simplify_expressions/utils.rs | 62 +-
7 files changed, 997 insertions(+), 7 deletions(-)
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 212328121..3f2f4bb39 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -1019,7 +1019,7 @@ impl ScalarValue {
Self::List(scalars, Box::new(Field::new("item", child_type, true)))
}
- // Create a zero value in the given type.
+ /// Create a zero value in the given type.
pub fn new_zero(datatype: &DataType) -> Result<ScalarValue> {
assert!(datatype.is_primitive());
Ok(match datatype {
@@ -1042,6 +1042,24 @@ impl ScalarValue {
})
}
+ /// Create a negative one (-1) value for `datatype`; returns NotImplemented for non-numeric primitives. NOTE(review): unsigned types map to the signed scalar of the same width (e.g. UInt32 -> Int32(-1)), so the literal's type differs from `datatype` — confirm this is intended.
+ pub fn new_negative_one(datatype: &DataType) -> Result<ScalarValue> {
+ assert!(datatype.is_primitive()); // panics (rather than returning Err) for non-primitive types
+ Ok(match datatype {
+ DataType::Int8 | DataType::UInt8 => ScalarValue::Int8(Some(-1)),
+ DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)),
+ DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)),
+ DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)),
+ DataType::Float32 => ScalarValue::Float32(Some(-1.0)),
+ DataType::Float64 => ScalarValue::Float64(Some(-1.0)),
+ _ => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Can't create a negative one scalar from data_type
\"{datatype:?}\""
+ )));
+ }
+ })
+ }
+
/// Getter for the `DataType` of the value
pub fn get_datatype(&self) -> DataType {
match self {
diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs
index 6e74fc0d9..f6b944b50 100644
--- a/datafusion/core/tests/simplification.rs
+++ b/datafusion/core/tests/simplification.rs
@@ -49,6 +49,13 @@ impl SimplifyInfo for MyInfo {
fn execution_props(&self) -> &ExecutionProps {
&self.execution_props
}
+
+    /// Returns the data type of `expr`, resolved against this info's schema.
+    ///
+    /// Errors from `Expr::get_type` are propagated unchanged.
+    fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
+        expr.get_type(&self.schema)
+    }
}
impl From<DFSchema> for MyInfo {
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 8b6c39043..1530513e9 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -639,6 +639,31 @@ impl Expr {
binary_expr(self, Operator::Or, other)
}
+ /// Return `self & other` (a [`Operator::BitwiseAnd`] binary expression)
+ pub fn bitwise_and(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::BitwiseAnd, other)
+ }
+
+ /// Return `self | other` (a [`Operator::BitwiseOr`] binary expression)
+ pub fn bitwise_or(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::BitwiseOr, other)
+ }
+
+ /// Return `self ^ other` (a [`Operator::BitwiseXor`] binary expression)
+ pub fn bitwise_xor(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::BitwiseXor, other)
+ }
+
+ /// Return `self >> other` (a [`Operator::BitwiseShiftRight`] binary expression)
+ pub fn bitwise_shift_right(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::BitwiseShiftRight, other)
+ }
+
+ /// Return `self << other` (a [`Operator::BitwiseShiftLeft`] binary expression)
+ pub fn bitwise_shift_left(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::BitwiseShiftLeft, other)
+ }
+
/// Return `!self`
#[allow(clippy::should_implement_trait)]
pub fn not(self) -> Expr {
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 325fac57c..6465ca80b 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -112,6 +112,51 @@ pub fn count(expr: Expr) -> Expr {
))
}
+/// Return a new expression with bitwise AND (`left & right`).
+///
+/// Free-function form of [`Expr::bitwise_and`]; both delegate to
+/// [`binary_expr`], so they build the same `Expr::BinaryExpr`.
+/// Operand types are not checked when the expression is built.
+pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
+    binary_expr(left, Operator::BitwiseAnd, right)
+}
+
+/// Return a new expression with bitwise OR (`left | right`).
+///
+/// Free-function form of [`Expr::bitwise_or`]; both delegate to
+/// [`binary_expr`], so they build the same `Expr::BinaryExpr`.
+/// Operand types are not checked when the expression is built.
+pub fn bitwise_or(left: Expr, right: Expr) -> Expr {
+    binary_expr(left, Operator::BitwiseOr, right)
+}
+
+/// Return a new expression with bitwise XOR (`left ^ right`).
+///
+/// Free-function form of [`Expr::bitwise_xor`]; both delegate to
+/// [`binary_expr`], so they build the same `Expr::BinaryExpr`.
+/// Operand types are not checked when the expression is built.
+pub fn bitwise_xor(left: Expr, right: Expr) -> Expr {
+    binary_expr(left, Operator::BitwiseXor, right)
+}
+
+/// Return a new expression with bitwise SHIFT RIGHT (`left >> right`).
+///
+/// Free-function form of [`Expr::bitwise_shift_right`]; both delegate to
+/// [`binary_expr`], so they build the same `Expr::BinaryExpr`.
+/// Operand types are not checked when the expression is built.
+pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr {
+    binary_expr(left, Operator::BitwiseShiftRight, right)
+}
+
+/// Return a new expression with bitwise SHIFT LEFT (`left << right`).
+///
+/// Free-function form of [`Expr::bitwise_shift_left`]; both delegate to
+/// [`binary_expr`], so they build the same `Expr::BinaryExpr`.
+/// Operand types are not checked when the expression is built.
+pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
+    binary_expr(left, Operator::BitwiseShiftLeft, right)
+}
+
/// Create an expression to represent the count(distinct) aggregate function
pub fn count_distinct(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
diff --git a/datafusion/optimizer/src/simplify_expressions/context.rs b/datafusion/optimizer/src/simplify_expressions/context.rs
index 379a803f4..b6e3f07b2 100644
--- a/datafusion/optimizer/src/simplify_expressions/context.rs
+++ b/datafusion/optimizer/src/simplify_expressions/context.rs
@@ -38,6 +38,9 @@ pub trait SimplifyInfo {
/// Returns details needed for partial expression evaluation
fn execution_props(&self) -> &ExecutionProps;
+
+    /// Returns the data type of `expr`; used to type literals (e.g. `0`, `-1`) produced by rewrites
+    fn get_data_type(&self, expr: &Expr) -> Result<DataType>;
}
/// Provides simplification information based on DFSchema and
@@ -123,6 +126,21 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> {
})
}
+    /// Returns the data type of `expr`, resolved against this context's schema.
+    ///
+    /// Only a single input schema is supported: with zero or several schemas
+    /// the type cannot be resolved unambiguously, so an internal error is
+    /// returned instead.
+    fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
+        match &self.schemas[..] {
+            [schema] => expr.get_type(schema),
+            schemas => Err(DataFusionError::Internal(format!(
+                "Could not determine data type of expr: expected exactly one schema, found {}",
+                schemas.len()
+            ))),
+        }
+    }
+
fn execution_props(&self) -> &ExecutionProps {
self.props
}
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 5421c43da..220d532b0 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -70,6 +70,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// `b > 2`
///
/// ```
+ /// use arrow::datatypes::DataType;
/// use datafusion_expr::{col, lit, Expr};
/// use datafusion_common::Result;
/// use datafusion_physical_expr::execution_props::ExecutionProps;
@@ -92,6 +93,9 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// fn execution_props(&self) -> &ExecutionProps {
/// &self.execution_props
/// }
+ /// fn get_data_type(&self, _expr: &Expr) -> Result<DataType> {
+ /// Ok(DataType::Int32) // this example treats every expression as Int32
+ /// }
/// }
///
/// // Create the simplifier
@@ -337,7 +341,8 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a,
S> {
/// rewrite the expression simplifying any constant expressions
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
use datafusion_expr::Operator::{
- And, Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch,
RegexMatch,
+ And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight,
BitwiseXor,
+ Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch,
RegexNotIMatch, RegexNotMatch,
};
@@ -700,6 +705,298 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a,
S> {
return
Err(DataFusionError::ArrowError(ArrowError::DivideByZero));
}
+ //
+ // Rules for BitwiseAnd
+ //
+
+ // A & null -> null
+ Expr::BinaryExpr(BinaryExpr {
+ left: _,
+ op: BitwiseAnd,
+ right,
+ }) if is_null(&right) => *right,
+
+ // null & A -> null
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseAnd,
+ right: _,
+ }) if is_null(&left) => *left,
+
+ // A & 0 -> 0 (if A not nullable)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseAnd,
+ right,
+ }) if !info.nullable(&left)? && is_zero(&right) => *right,
+
+ // 0 & A -> 0 (if A not nullable)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseAnd,
+ right,
+ }) if !info.nullable(&right)? && is_zero(&left) => *left,
+
+ // !A & A -> 0 (if A not nullable). NOTE(review): is_negative_of matches Expr::Negative, which appears to be arithmetic negation (-A), not bitwise NOT; -A & A is the lowest set bit of A (e.g. -2 & 2 == 2), not 0 — confirm before relying on this rule
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseAnd,
+ right,
+ }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
+
Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?)
+ }
+
+ // A & !A -> 0 (if A not nullable). NOTE(review): same Expr::Negative concern as the `!A & A` arm above
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseAnd,
+ right,
+ }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
+
Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?)
+ }
+
+ // (..A..) & A --> (..A..)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseAnd,
+ right,
+ }) if expr_contains(&left, &right, BitwiseAnd) => *left,
+
+ // A & (..A..) --> (..A..)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseAnd,
+ right,
+ }) if expr_contains(&right, &left, BitwiseAnd) => *right,
+
+ // A & (A | B) --> A (if A | B is not nullable)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseAnd,
+ right,
+ }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right,
&left) => {
+ *left
+ }
+
+ // (A | B) & A --> A (if A | B is not nullable)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseAnd,
+ right,
+ }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left,
&right) => {
+ *right
+ }
+
+ //
+ // Rules for BitwiseOr
+ //
+
+ // A | null -> null
+ Expr::BinaryExpr(BinaryExpr {
+ left: _,
+ op: BitwiseOr,
+ right,
+ }) if is_null(&right) => *right,
+
+ // null | A -> null
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseOr,
+ right: _,
+ }) if is_null(&left) => *left,
+
+ // A | 0 -> A (even if A is null)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseOr,
+ right,
+ }) if is_zero(&right) => *left,
+
+ // 0 | A -> A (even if A is null)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseOr,
+ right,
+ }) if is_zero(&left) => *right,
+
+ // !A | A -> -1 (if A not nullable). NOTE(review): is_negative_of matches Expr::Negative (apparently arithmetic -A); -A | A is not -1 in general (e.g. -2 | 2 == -2), and new_negative_one yields a signed literal for unsigned A — confirm
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseOr,
+ right,
+ }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
+
Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?)
+ }
+
+ // A | !A -> -1 (if A not nullable). NOTE(review): same Expr::Negative concern as the `!A | A` arm above
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseOr,
+ right,
+ }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
+
Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?)
+ }
+
+ // (..A..) | A --> (..A..)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseOr,
+ right,
+ }) if expr_contains(&left, &right, BitwiseOr) => *left,
+
+ // A | (..A..) --> (..A..)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseOr,
+ right,
+ }) if expr_contains(&right, &left, BitwiseOr) => *right,
+
+ // A | (A & B) --> A (if A & B is not nullable)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseOr,
+ right,
+ }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right,
&left) => {
+ *left
+ }
+
+ // (A & B) | A --> A (if A & B is not nullable)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseOr,
+ right,
+ }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left,
&right) => {
+ *right
+ }
+
+ //
+ // Rules for BitwiseXor
+ //
+
+ // A ^ null -> null
+ Expr::BinaryExpr(BinaryExpr {
+ left: _,
+ op: BitwiseXor,
+ right,
+ }) if is_null(&right) => *right,
+
+ // null ^ A -> null
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseXor,
+ right: _,
+ }) if is_null(&left) => *left,
+
+ // A ^ 0 -> A (if A not nullable)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseXor,
+ right,
+ }) if !info.nullable(&left)? && is_zero(&right) => *left,
+
+ // 0 ^ A -> A (if A not nullable)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseXor,
+ right,
+ }) if !info.nullable(&right)? && is_zero(&left) => *right,
+
+ // !A ^ A -> -1 (if A not nullable). NOTE(review): is_negative_of matches Expr::Negative (apparently arithmetic -A); -A ^ A is not -1 in general (e.g. -2 ^ 2 == -4) — confirm
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseXor,
+ right,
+ }) if is_negative_of(&left, &right) && !info.nullable(&right)? => {
+
Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?)
+ }
+
+ // A ^ !A -> -1 (if A not nullable). NOTE(review): same Expr::Negative concern as the `!A ^ A` arm above
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseXor,
+ right,
+ }) if is_negative_of(&right, &left) && !info.nullable(&left)? => {
+
Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?)
+ }
+
+ // (..A..) ^ A --> (..A..) with its A's removed; one A is kept if (..A..) held an even number of A; if nothing else remains the result is 0. NOTE(review): no nullability guard — when A is NULL the original is NULL but e.g. A ^ A is rewritten to 0; confirm
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseXor,
+ right,
+ }) if expr_contains(&left, &right, BitwiseXor) => {
+ let expr = delete_xor_in_complex_expr(&left, &right, false);
+ if expr == *right {
+
Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?)
+ } else {
+ expr
+ }
+ }
+
+ // A ^ (..A..) --> (..A..) with its A's removed; one A is kept if (..A..) held an even number of A; if nothing else remains the result is 0. NOTE(review): same missing nullability guard as the arm above
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseXor,
+ right,
+ }) if expr_contains(&right, &left, BitwiseXor) => {
+ let expr = delete_xor_in_complex_expr(&right, &left, true);
+ if expr == *left {
+
Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?)
+ } else {
+ expr
+ }
+ }
+
+ //
+ // Rules for BitwiseShiftRight
+ //
+
+ // A >> null -> null
+ Expr::BinaryExpr(BinaryExpr {
+ left: _,
+ op: BitwiseShiftRight,
+ right,
+ }) if is_null(&right) => *right,
+
+ // null >> A -> null
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseShiftRight,
+ right: _,
+ }) if is_null(&left) => *left,
+
+ // A >> 0 -> A (even if A is null)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseShiftRight,
+ right,
+ }) if is_zero(&right) => *left,
+
+ //
+ // Rules for BitwiseShiftLeft
+ //
+
+ // A << null -> null
+ Expr::BinaryExpr(BinaryExpr {
+ left: _,
+ op: BitwiseShiftLeft,
+ right,
+ }) if is_null(&right) => *right,
+
+ // null << A -> null
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseShiftLeft,
+ right: _,
+ }) if is_null(&left) => *left,
+
+ // A << 0 -> A (even if A is null)
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: BitwiseShiftLeft,
+ right,
+ }) if is_zero(&right) => *left,
+
//
// Rules for Not
//
@@ -1346,6 +1643,522 @@ mod tests {
assert_eq!(simplify(expr), expected);
}
+ #[test]
+ fn test_simplify_bitwise_xor_by_null() {
+ let null = Expr::Literal(ScalarValue::Null);
+ // A ^ null --> null
+ {
+ let expr = binary_expr(col("c2"), Operator::BitwiseXor,
null.clone());
+ assert_eq!(simplify(expr), null);
+ }
+ // null ^ A --> null
+ {
+ let expr = binary_expr(null.clone(), Operator::BitwiseXor,
col("c2"));
+ assert_eq!(simplify(expr), null);
+ }
+ }
+
+ #[test]
+ fn test_simplify_bitwise_shift_right_by_null() {
+ let null = Expr::Literal(ScalarValue::Null);
+ // A >> null --> null
+ {
+ let expr = binary_expr(col("c2"), Operator::BitwiseShiftRight,
null.clone());
+ assert_eq!(simplify(expr), null);
+ }
+ // null >> A --> null
+ {
+ let expr = binary_expr(null.clone(), Operator::BitwiseShiftRight,
col("c2"));
+ assert_eq!(simplify(expr), null);
+ }
+ }
+
+ #[test]
+ fn test_simplify_bitwise_shift_left_by_null() {
+ let null = Expr::Literal(ScalarValue::Null);
+ // A << null --> null
+ {
+ let expr = binary_expr(col("c2"), Operator::BitwiseShiftLeft,
null.clone());
+ assert_eq!(simplify(expr), null);
+ }
+ // null << A --> null
+ {
+ let expr = binary_expr(null.clone(), Operator::BitwiseShiftLeft,
col("c2"));
+ assert_eq!(simplify(expr), null);
+ }
+ }
+
+ #[test]
+ fn test_simplify_bitwise_and_by_zero() {
+ // A & 0 --> 0 (A is non-nullable here, as the rule requires)
+ {
+ let expr = binary_expr(col("c2_non_null"), Operator::BitwiseAnd,
lit(0));
+ assert_eq!(simplify(expr), lit(0));
+ }
+ // 0 & A --> 0 (A is non-nullable here, as the rule requires)
+ {
+ let expr = binary_expr(lit(0), Operator::BitwiseAnd,
col("c2_non_null"));
+ assert_eq!(simplify(expr), lit(0));
+ }
+ }
+
+ #[test]
+ fn test_simplify_bitwise_or_by_zero() {
+ // A | 0 --> A
+ {
+ let expr = binary_expr(col("c2_non_null"), Operator::BitwiseOr,
lit(0));
+ assert_eq!(simplify(expr), col("c2_non_null"));
+ }
+ // 0 | A --> A
+ {
+ let expr = binary_expr(lit(0), Operator::BitwiseOr,
col("c2_non_null"));
+ assert_eq!(simplify(expr), col("c2_non_null"));
+ }
+ }
+
+ #[test]
+ fn test_simplify_bitwise_xor_by_zero() {
+ // A ^ 0 --> A (A is non-nullable here, as the rule requires)
+ {
+ let expr = binary_expr(col("c2_non_null"), Operator::BitwiseXor,
lit(0));
+ assert_eq!(simplify(expr), col("c2_non_null"));
+ }
+ // 0 ^ A --> A (A is non-nullable here, as the rule requires)
+ {
+ let expr = binary_expr(lit(0), Operator::BitwiseXor,
col("c2_non_null"));
+ assert_eq!(simplify(expr), col("c2_non_null"));
+ }
+ }
+
+ #[test]
+ fn test_simplify_bitwise_bitwise_shift_right_by_zero() {
+ // A >> 0 --> A
+ {
+ let expr =
+ binary_expr(col("c2_non_null"), Operator::BitwiseShiftRight,
lit(0));
+ assert_eq!(simplify(expr), col("c2_non_null"));
+ }
+ }
+
+ #[test]
+ fn test_simplify_bitwise_bitwise_shift_left_by_zero() {
+ // A << 0 --> A
+ {
+ let expr =
+ binary_expr(col("c2_non_null"), Operator::BitwiseShiftLeft,
lit(0));
+ assert_eq!(simplify(expr), col("c2_non_null"));
+ }
+ }
+
+ #[test]
+ fn test_simplify_bitwise_and_by_null() {
+ let null = Expr::Literal(ScalarValue::Null);
+ // A & null --> null
+ {
+ let expr = binary_expr(col("c2"), Operator::BitwiseAnd,
null.clone());
+ assert_eq!(simplify(expr), null);
+ }
+ // null & A --> null
+ {
+ let expr = binary_expr(null.clone(), Operator::BitwiseAnd,
col("c2"));
+ assert_eq!(simplify(expr), null);
+ }
+ }
+
+ #[test]
+ fn test_simplify_composed_bitwise_and() {
+ // ((c2 > 5) & (c1 < 6)) & (c2 > 5) --> (c2 > 5) & (c1 < 6)
+
+ let expr = binary_expr(
+ binary_expr(
+ col("c2").gt(lit(5)),
+ Operator::BitwiseAnd,
+ col("c1").lt(lit(6)),
+ ),
+ Operator::BitwiseAnd,
+ col("c2").gt(lit(5)),
+ );
+ let expected = binary_expr(
+ col("c2").gt(lit(5)),
+ Operator::BitwiseAnd,
+ col("c1").lt(lit(6)),
+ );
+
+ assert_eq!(simplify(expr), expected);
+
+ // (c2 > 5) & ((c2 > 5) & (c1 < 6)) --> (c2 > 5) & (c1 < 6)
+
+ let expr = binary_expr(
+ col("c2").gt(lit(5)),
+ Operator::BitwiseAnd,
+ binary_expr(
+ col("c2").gt(lit(5)),
+ Operator::BitwiseAnd,
+ col("c1").lt(lit(6)),
+ ),
+ );
+ let expected = binary_expr(
+ col("c2").gt(lit(5)),
+ Operator::BitwiseAnd,
+ col("c1").lt(lit(6)),
+ );
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_composed_bitwise_or() {
+ // ((c2 > 5) | (c1 < 6)) | (c2 > 5) --> (c2 > 5) | (c1 < 6)
+
+ let expr = binary_expr(
+ binary_expr(
+ col("c2").gt(lit(5)),
+ Operator::BitwiseOr,
+ col("c1").lt(lit(6)),
+ ),
+ Operator::BitwiseOr,
+ col("c2").gt(lit(5)),
+ );
+ let expected = binary_expr(
+ col("c2").gt(lit(5)),
+ Operator::BitwiseOr,
+ col("c1").lt(lit(6)),
+ );
+
+ assert_eq!(simplify(expr), expected);
+
+ // (c2 > 5) | ((c2 > 5) | (c1 < 6)) --> (c2 > 5) | (c1 < 6)
+
+ let expr = binary_expr(
+ col("c2").gt(lit(5)),
+ Operator::BitwiseOr,
+ binary_expr(
+ col("c2").gt(lit(5)),
+ Operator::BitwiseOr,
+ col("c1").lt(lit(6)),
+ ),
+ );
+ let expected = binary_expr(
+ col("c2").gt(lit(5)),
+ Operator::BitwiseOr,
+ col("c1").lt(lit(6)),
+ );
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_composed_bitwise_xor() {
+ // with an even number of the column "c2"
+ // c2 ^ ((c2 ^ (c2 | c1)) ^ (c1 & c2)) --> (c2 | c1) ^ (c1 & c2)
+
+ let expr = binary_expr(
+ col("c2"),
+ Operator::BitwiseXor,
+ binary_expr(
+ binary_expr(
+ col("c2"),
+ Operator::BitwiseXor,
+ binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+ ),
+ Operator::BitwiseXor,
+ binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+ ),
+ );
+
+ let expected = binary_expr(
+ binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+ Operator::BitwiseXor,
+ binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+ );
+
+ assert_eq!(simplify(expr), expected);
+
+ // with an odd number of the column "c2"
+ // c2 ^ (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) --> c2 ^ ((c2 | c1) ^ (c1 & c2))
+
+ let expr = binary_expr(
+ col("c2"),
+ Operator::BitwiseXor,
+ binary_expr(
+ binary_expr(
+ col("c2"),
+ Operator::BitwiseXor,
+ binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+ ),
+ Operator::BitwiseXor,
+ binary_expr(
+ binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+ Operator::BitwiseXor,
+ col("c2"),
+ ),
+ ),
+ );
+
+ let expected = binary_expr(
+ col("c2"),
+ Operator::BitwiseXor,
+ binary_expr(
+ binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+ Operator::BitwiseXor,
+ binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+ ),
+ );
+
+ assert_eq!(simplify(expr), expected);
+
+ // with an even number of the column "c2"
+ // ((c2 ^ (c2 | c1)) ^ (c1 & c2)) ^ c2 --> (c2 | c1) ^ (c1 & c2)
+
+ let expr = binary_expr(
+ binary_expr(
+ binary_expr(
+ col("c2"),
+ Operator::BitwiseXor,
+ binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+ ),
+ Operator::BitwiseXor,
+ binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+ ),
+ Operator::BitwiseXor,
+ col("c2"),
+ );
+
+ let expected = binary_expr(
+ binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+ Operator::BitwiseXor,
+ binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+ );
+
+ assert_eq!(simplify(expr), expected);
+
+ // with an odd number of the column "c2"
+ // (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) ^ c2 --> ((c2 | c1) ^ (c1 & c2)) ^ c2
+
+ let expr = binary_expr(
+ binary_expr(
+ binary_expr(
+ col("c2"),
+ Operator::BitwiseXor,
+ binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+ ),
+ Operator::BitwiseXor,
+ binary_expr(
+ binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+ Operator::BitwiseXor,
+ col("c2"),
+ ),
+ ),
+ Operator::BitwiseXor,
+ col("c2"),
+ );
+
+ let expected = binary_expr(
+ binary_expr(
+ binary_expr(col("c2"), Operator::BitwiseOr, col("c1")),
+ Operator::BitwiseXor,
+ binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")),
+ ),
+ Operator::BitwiseXor,
+ col("c2"),
+ );
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_negated_bitwise_and() {
+ // !c4 & c4 --> 0 (here `!` is built as Expr::Negative. NOTE(review): if Expr::Negative is arithmetic -A, this identity does not hold in general, e.g. -2 & 2 == 2 — confirm)
+ let expr = binary_expr(
+ Expr::Negative(Box::new(col("c4_non_null"))),
+ Operator::BitwiseAnd,
+ col("c4_non_null"),
+ );
+ let expected = Expr::Literal(ScalarValue::UInt32(Some(0)));
+
+ assert_eq!(simplify(expr), expected);
+ // c4 & !c4 --> 0
+ let expr = binary_expr(
+ col("c4_non_null"),
+ Operator::BitwiseAnd,
+ Expr::Negative(Box::new(col("c4_non_null"))),
+ );
+ let expected = Expr::Literal(ScalarValue::UInt32(Some(0)));
+
+ assert_eq!(simplify(expr), expected);
+
+ // !c3 & c3 --> 0
+ let expr = binary_expr(
+ Expr::Negative(Box::new(col("c3_non_null"))),
+ Operator::BitwiseAnd,
+ col("c3_non_null"),
+ );
+ let expected = Expr::Literal(ScalarValue::Int64(Some(0)));
+
+ assert_eq!(simplify(expr), expected);
+ // c3 & !c3 --> 0
+ let expr = binary_expr(
+ col("c3_non_null"),
+ Operator::BitwiseAnd,
+ Expr::Negative(Box::new(col("c3_non_null"))),
+ );
+ let expected = Expr::Literal(ScalarValue::Int64(Some(0)));
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_negated_bitwise_or() {
+ // !c4 | c4 --> -1 (NOTE(review): the expected literal is Int32 although c4 is UInt32, because new_negative_one maps UInt32 to Int32; and for arithmetic -A, -2 | 2 == -2 — confirm)
+ let expr = binary_expr(
+ Expr::Negative(Box::new(col("c4_non_null"))),
+ Operator::BitwiseOr,
+ col("c4_non_null"),
+ );
+ let expected = Expr::Literal(ScalarValue::Int32(Some(-1)));
+
+ assert_eq!(simplify(expr), expected);
+
+ // c4 | !c4 --> -1
+ let expr = binary_expr(
+ col("c4_non_null"),
+ Operator::BitwiseOr,
+ Expr::Negative(Box::new(col("c4_non_null"))),
+ );
+ let expected = Expr::Literal(ScalarValue::Int32(Some(-1)));
+
+ assert_eq!(simplify(expr), expected);
+
+ // !c3 | c3 --> -1
+ let expr = binary_expr(
+ Expr::Negative(Box::new(col("c3_non_null"))),
+ Operator::BitwiseOr,
+ col("c3_non_null"),
+ );
+ let expected = Expr::Literal(ScalarValue::Int64(Some(-1)));
+
+ assert_eq!(simplify(expr), expected);
+
+ // c3 | !c3 --> -1
+ let expr = binary_expr(
+ col("c3_non_null"),
+ Operator::BitwiseOr,
+ Expr::Negative(Box::new(col("c3_non_null"))),
+ );
+ let expected = Expr::Literal(ScalarValue::Int64(Some(-1)));
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_negated_bitwise_xor() {
+ // !c4 ^ c4 --> -1 (NOTE(review): Int32 literal for a UInt32 column, and for arithmetic -A, -2 ^ 2 == -4 — confirm)
+ let expr = binary_expr(
+ Expr::Negative(Box::new(col("c4_non_null"))),
+ Operator::BitwiseXor,
+ col("c4_non_null"),
+ );
+ let expected = Expr::Literal(ScalarValue::Int32(Some(-1)));
+
+ assert_eq!(simplify(expr), expected);
+
+ // c4 ^ !c4 --> -1
+ let expr = binary_expr(
+ col("c4_non_null"),
+ Operator::BitwiseXor,
+ Expr::Negative(Box::new(col("c4_non_null"))),
+ );
+ let expected = Expr::Literal(ScalarValue::Int32(Some(-1)));
+
+ assert_eq!(simplify(expr), expected);
+
+ // !c3 ^ c3 --> -1
+ let expr = binary_expr(
+ Expr::Negative(Box::new(col("c3_non_null"))),
+ Operator::BitwiseXor,
+ col("c3_non_null"),
+ );
+ let expected = Expr::Literal(ScalarValue::Int64(Some(-1)));
+
+ assert_eq!(simplify(expr), expected);
+
+ // c3 ^ !c3 --> -1
+ let expr = binary_expr(
+ col("c3_non_null"),
+ Operator::BitwiseXor,
+ Expr::Negative(Box::new(col("c3_non_null"))),
+ );
+ let expected = Expr::Literal(ScalarValue::Int64(Some(-1)));
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_bitwise_and_or() {
+ // (c2 < 3) & ((c2 < 3) | c1) -> (c2 < 3)  (non-nullable columns, as the absorption rule requires)
+ let expr = binary_expr(
+ col("c2_non_null").lt(lit(3)),
+ Operator::BitwiseAnd,
+ binary_expr(
+ col("c2_non_null").lt(lit(3)),
+ Operator::BitwiseOr,
+ col("c1_non_null"),
+ ),
+ );
+ let expected = col("c2_non_null").lt(lit(3));
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_bitwise_or_and() {
+ // (c2 < 3) | ((c2 < 3) & c1) -> (c2 < 3)  (non-nullable columns, as the absorption rule requires)
+ let expr = binary_expr(
+ col("c2_non_null").lt(lit(3)),
+ Operator::BitwiseOr,
+ binary_expr(
+ col("c2_non_null").lt(lit(3)),
+ Operator::BitwiseAnd,
+ col("c1_non_null"),
+ ),
+ );
+ let expected = col("c2_non_null").lt(lit(3));
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_simple_bitwise_and() {
+ // (c2 > 5) & (c2 > 5) -> (c2 > 5)
+ let expr = (col("c2").gt(lit(5))).bitwise_and(col("c2").gt(lit(5)));
+ let expected = col("c2").gt(lit(5));
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_simple_bitwise_or() {
+ // (c2 > 5) | (c2 > 5) -> (c2 > 5)
+ let expr = (col("c2").gt(lit(5))).bitwise_or(col("c2").gt(lit(5)));
+ let expected = col("c2").gt(lit(5));
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_simple_bitwise_xor() {
+ // c4 ^ c4 -> 0 (NOTE(review): c4 is nullable in expr_test_schema and NULL ^ x is NULL, so 0 is only right for non-null rows; this pins the XOR rule's missing nullability guard — confirm)
+ let expr = (col("c4")).bitwise_xor(col("c4"));
+ let expected = Expr::Literal(ScalarValue::UInt32(Some(0)));
+
+ assert_eq!(simplify(expr), expected);
+
+ // c3 ^ c3 -> 0
+ let expr = col("c3").bitwise_xor(col("c3"));
+ let expected = Expr::Literal(ScalarValue::Int64(Some(0)));
+
+ assert_eq!(simplify(expr), expected);
+ }
+
#[test]
#[should_panic(
expected = "called `Result::unwrap()` on an `Err` value:
ArrowError(DivideByZero)"
@@ -1357,7 +2170,7 @@ mod tests {
#[test]
fn test_simplify_simple_and() {
- // (c > 5) AND (c > 5)
+ // (c2 > 5) AND (c2 > 5) --> (c2 > 5)
let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5)));
let expected = col("c2").gt(lit(5));
@@ -1366,7 +2179,7 @@ mod tests {
#[test]
fn test_simplify_composed_and() {
- // ((c > 5) AND (c1 < 6)) AND (c > 5)
+ // ((c2 > 5) AND (c1 < 6)) AND (c2 > 5)
let expr = binary_expr(
binary_expr(col("c2").gt(lit(5)), Operator::And,
col("c1").lt(lit(6))),
Operator::And,
@@ -1380,7 +2193,7 @@ mod tests {
#[test]
fn test_simplify_negated_and() {
- // (c > 5) AND !(c > 5) -- > (c > 5) AND (c <= 5)
+ // (c2 > 5) AND !(c2 > 5) --> (c2 > 5) AND (c2 <= 5)
let expr = binary_expr(
col("c2").gt(lit(5)),
Operator::And,
@@ -1760,8 +2573,12 @@ mod tests {
vec![
DFField::new(None, "c1", DataType::Utf8, true),
DFField::new(None, "c2", DataType::Boolean, true),
+ DFField::new(None, "c3", DataType::Int64, true),
+ DFField::new(None, "c4", DataType::UInt32, true),
DFField::new(None, "c1_non_null", DataType::Utf8, false),
DFField::new(None, "c2_non_null", DataType::Boolean,
false),
+ DFField::new(None, "c3_non_null", DataType::Int64, false),
+ DFField::new(None, "c4_non_null", DataType::UInt32, false),
],
HashMap::new(),
)
@@ -1808,7 +2625,7 @@ mod tests {
let schema = expr_test_schema();
assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
- // true = ture -> true
+ // true = true -> true
assert_eq!(simplify(lit(true).eq(lit(true))), lit(true));
// true = false -> false
diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs
b/datafusion/optimizer/src/simplify_expressions/utils.rs
index 10e5c0e87..352674c3a 100644
--- a/datafusion/optimizer/src/simplify_expressions/utils.rs
+++ b/datafusion/optimizer/src/simplify_expressions/utils.rs
@@ -77,6 +77,61 @@ pub fn expr_contains(expr: &Expr, needle: &Expr, search_op:
Operator) -> bool {
}
}
+/// Removes every `needle` that is a direct operand of an XOR node in the XOR chain `expr` (e.g. A ^ (A ^ (B ^ A))).
+/// Returns `needle` itself when nothing else remains (callers map that to 0); otherwise, if an even number of needles was removed, re-attaches one `needle` (as the left operand when `is_left`, else the right one) and returns the result.
+pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool)
-> Expr {
+ /// Rebuilds `expr`, dropping XOR operands equal to `needle` and incrementing `xor_counter` once per dropped operand
+ fn recursive_delete_xor_in_expr(
+ expr: &Expr,
+ needle: &Expr,
+ xor_counter: &mut i32,
+ ) -> Expr {
+ match expr {
+ Expr::BinaryExpr(BinaryExpr { left, op, right })
+ if *op == Operator::BitwiseXor =>
+ {
+ let left_expr = recursive_delete_xor_in_expr(left, needle,
xor_counter);
+ let right_expr = recursive_delete_xor_in_expr(right, needle,
xor_counter);
+ if left_expr == *needle { // NOTE(review): if right_expr is also `needle`, only one operand is dropped/counted and `needle` is returned
+ *xor_counter += 1;
+ return right_expr;
+ } else if right_expr == *needle {
+ *xor_counter += 1;
+ return left_expr;
+ }
+
+ Expr::BinaryExpr(BinaryExpr::new(
+ Box::new(left_expr),
+ *op,
+ Box::new(right_expr),
+ ))
+ }
+ _ => expr.clone(),
+ }
+ }
+
+ let mut xor_counter: i32 = 0;
+ let result_expr = recursive_delete_xor_in_expr(expr, needle, &mut
xor_counter);
+ if result_expr == *needle { // NOTE(review): ignores xor_counter parity — e.g. expr = A ^ A, needle = A returns A (caller maps it to 0) although (A ^ A) ^ A == A; presumably masked if A ^ A is always folded first — confirm
+ return needle.clone();
+ } else if xor_counter % 2 == 0 {
+ if is_left {
+ return Expr::BinaryExpr(BinaryExpr::new(
+ Box::new(needle.clone()),
+ Operator::BitwiseXor,
+ Box::new(result_expr),
+ ));
+ } else {
+ return Expr::BinaryExpr(BinaryExpr::new(
+ Box::new(result_expr),
+ Operator::BitwiseXor,
+ Box::new(needle.clone()),
+ ));
+ }
+ }
+ result_expr
+}
+
pub fn is_zero(s: &Expr) -> bool {
match s {
Expr::Literal(ScalarValue::Int8(Some(0)))
@@ -154,11 +209,16 @@ pub fn is_op_with(target_op: Operator, haystack: &Expr,
needle: &Expr) -> bool {
matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op
== &target_op && (needle == left.as_ref() || needle == right.as_ref()))
}
-/// returns true if `not_expr` is !`expr`
+/// returns true if `not_expr` is `Expr::Not(expr)`, i.e. logical NOT of `expr`
pub fn is_not_of(not_expr: &Expr, expr: &Expr) -> bool {
matches!(not_expr, Expr::Not(inner) if expr == inner.as_ref())
}
+/// returns true if `not_expr` is `Expr::Negative(expr)`. NOTE(review): Expr::Negative appears to be arithmetic negation (-`expr`), not bitwise NOT, yet callers use this for bitwise identities — confirm intent
+pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool {
+ matches!(not_expr, Expr::Negative(inner) if expr == inner.as_ref())
+}
+
/// returns the contained boolean value in `expr` as
/// `Expr::Literal(ScalarValue::Boolean(v))`.
pub fn as_bool_lit(expr: Expr) -> Result<Option<bool>> {