This is an automated email from the ASF dual-hosted git repository.
xudong963 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 73447b560 simplify the `between` expr during logical plan optimization
(#3404)
73447b560 is described below
commit 73447b560edea20773114f4ed8b49a561b91799d
Author: Kirk Mitchener <[email protected]>
AuthorDate: Fri Sep 9 08:27:55 2022 -0400
simplify the `between` expr during logical plan optimization (#3404)
* rewrite between expression so that it can be further optimized and pushed
down
* update tests
* update for comment and test
* fix common_subexpr_eliminate to retain predictable ordering between runs
---
datafusion/core/tests/sql/predicates.rs | 7 +-
datafusion/core/tests/sql/select.rs | 4 +-
.../optimizer/src/common_subexpr_eliminate.rs | 20 +++---
datafusion/optimizer/src/simplify_expressions.rs | 78 +++++++++++++++++++---
4 files changed, 83 insertions(+), 26 deletions(-)
diff --git a/datafusion/core/tests/sql/predicates.rs
b/datafusion/core/tests/sql/predicates.rs
index 3c11b690d..32365090a 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -427,11 +427,12 @@ async fn multiple_or_predicates() -> Result<()> {
let expected =vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
- " Projection: #part.p_partkey = #lineitem.l_partkey AS
#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey,
#lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size
[#part.p_partkey =
#lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N,
l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
- " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND
#lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size BETWEEN
Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND
#lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <=
CAST(Int64(20) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR
#part.p_brand = Utf8(\"Brand#34\") AN [...]
+ " Projection: #part.p_partkey = #lineitem.l_partkey AS
#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey,
#part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size,
#lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size
[#part.p_partkey =
#lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size
>= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Float64,
p_brand:Utf8, p_size:Int32]",
+ " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND
#lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size <= Int32(5)
OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >=
CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS
Float64) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\")
AND #lineitem.l_quantity >= CAST(Int64 [...]
" CrossJoin: [l_partkey:Int64, l_quantity:Float64,
p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity]
[l_partkey:Int64, l_quantity:Float64]",
- " TableScan: part projection=[p_partkey, p_brand, p_size]
[p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
+ " Filter: #part.p_size >= Int32(1) [p_partkey:Int64,
p_brand:Utf8, p_size:Int32]",
+ " TableScan: part projection=[p_partkey, p_brand, p_size],
partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8,
p_size:Int32]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
diff --git a/datafusion/core/tests/sql/select.rs
b/datafusion/core/tests/sql/select.rs
index 06353167c..54d1b24e8 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -495,10 +495,10 @@ async fn use_between_expression_in_select_query() ->
Result<()> {
.unwrap()
.to_string();
- // Only test that the projection exprs arecorrect, rather than entire
output
+ // Only test that the projection exprs are correct, rather than entire
output
let needle = "ProjectionExec: expr=[c1@0 >= 2 AND c1@0 <= 3 as test.c1
BETWEEN Int64(2) AND Int64(3)]";
assert_contains!(&formatted, needle);
- let needle = "Projection: #test.c1 BETWEEN Int64(2) AND Int64(3)";
+ let needle = "Projection: #test.c1 >= Int64(2) AND #test.c1 <= Int64(3)";
assert_contains!(&formatted, needle);
Ok(())
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs
b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 239939f81..978b79d37 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -28,7 +28,7 @@ use datafusion_expr::{
utils::from_plan,
Expr, ExprSchemable,
};
-use std::collections::{HashMap, HashSet};
+use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;
/// A map from expression's identifier to tuple including
@@ -271,12 +271,12 @@ fn to_arrays(
/// Build the "intermediate" projection plan that evaluates the extracted
common expressions.
fn build_project_plan(
input: LogicalPlan,
- affected_id: HashSet<Identifier>,
+ affected_id: BTreeSet<Identifier>,
expr_set: &ExprSet,
) -> Result<LogicalPlan> {
let mut project_exprs = vec![];
let mut fields = vec![];
- let mut fields_set = HashSet::new();
+ let mut fields_set = BTreeSet::new();
for id in affected_id {
match expr_set.get(&id) {
@@ -320,7 +320,7 @@ fn rewrite_expr(
expr_set: &mut ExprSet,
optimizer_config: &OptimizerConfig,
) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
- let mut affected_id = HashSet::<Identifier>::new();
+ let mut affected_id = BTreeSet::<Identifier>::new();
let rewrote_exprs = exprs_list
.iter()
@@ -482,7 +482,7 @@ struct CommonSubexprRewriter<'a> {
expr_set: &'a mut ExprSet,
id_array: &'a [(usize, Identifier)],
/// Which identifier is replaced.
- affected_id: &'a mut HashSet<Identifier>,
+ affected_id: &'a mut BTreeSet<Identifier>,
/// the max series number we have rewritten. Other expression nodes
/// with smaller series number is already replaced and shouldn't
@@ -561,7 +561,7 @@ fn replace_common_expr(
expr: Expr,
id_array: &[(usize, Identifier)],
expr_set: &mut ExprSet,
- affected_id: &mut HashSet<Identifier>,
+ affected_id: &mut BTreeSet<Identifier>,
) -> Result<Expr> {
expr.rewrite(&mut CommonSubexprRewriter {
expr_set,
@@ -752,7 +752,7 @@ mod test {
#[test]
fn redundant_project_fields() {
let table_scan = test_table_scan().unwrap();
- let affected_id: HashSet<Identifier> =
+ let affected_id: BTreeSet<Identifier> =
["c+a".to_string(), "d+a".to_string()].into_iter().collect();
let expr_set = [
("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
@@ -764,7 +764,7 @@ mod test {
build_project_plan(table_scan, affected_id.clone(),
&expr_set).unwrap();
let project_2 = build_project_plan(project, affected_id,
&expr_set).unwrap();
- let mut field_set = HashSet::new();
+ let mut field_set = BTreeSet::new();
for field in project_2.schema().fields() {
assert!(field_set.insert(field.qualified_name()));
}
@@ -779,7 +779,7 @@ mod test {
.unwrap()
.build()
.unwrap();
- let affected_id: HashSet<Identifier> =
+ let affected_id: BTreeSet<Identifier> =
["c+a".to_string(), "d+a".to_string()].into_iter().collect();
let expr_set = [
("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
@@ -790,7 +790,7 @@ mod test {
let project = build_project_plan(join, affected_id.clone(),
&expr_set).unwrap();
let project_2 = build_project_plan(project, affected_id,
&expr_set).unwrap();
- let mut field_set = HashSet::new();
+ let mut field_set = BTreeSet::new();
for field in project_2.schema().fields() {
assert!(field_set.insert(field.qualified_name()));
}
diff --git a/datafusion/optimizer/src/simplify_expressions.rs
b/datafusion/optimizer/src/simplify_expressions.rs
index d1afa3543..aa87c5318 100644
--- a/datafusion/optimizer/src/simplify_expressions.rs
+++ b/datafusion/optimizer/src/simplify_expressions.rs
@@ -164,8 +164,6 @@ fn is_op_with(target_op: Operator, haystack: &Expr, needle:
&Expr) -> bool {
/// returns the contained boolean value in `expr` as
/// `Expr::Literal(ScalarValue::Boolean(v))`.
-///
-/// panics if expr is not a literal boolean
fn as_bool_lit(expr: Expr) -> Result<Option<bool>> {
match expr {
Expr::Literal(ScalarValue::Boolean(v)) => Ok(v),
@@ -502,7 +500,7 @@ impl<'a> ConstEvaluator<'a> {
ColumnarValue::Array(a) => {
if a.len() != 1 {
Err(DataFusionError::Execution(format!(
- "Could not evaluate the expressison, found a result of
length {}",
+ "Could not evaluate the expression, found a result of
length {}",
a.len()
)))
} else {
@@ -803,6 +801,27 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a,
S> {
out_expr.rewrite(self)?
}
+ //
+ // Rules for Between
+ //
+
+ // a between 3 and 5 --> a >= 3 AND a <= 5
+ // a not between 3 and 5 --> a < 3 OR a > 5
+ Between {
+ expr,
+ low,
+ high,
+ negated,
+ } => {
+ if negated {
+ let l = *expr.clone();
+ let r = *expr;
+ or(l.lt(*low), r.gt(*high))
+ } else {
+ and(expr.clone().gt_eq(*low), expr.lt_eq(*high))
+ }
+ }
+
expr => {
// no additional rewrites possible
expr
@@ -1555,8 +1574,13 @@ mod tests {
high: Box::new(lit(10)),
};
let expr = expr.or(lit_bool_null());
- let result = simplify(expr.clone());
- assert_eq!(expr, result);
+ let result = simplify(expr);
+
+ let expected_expr = or(
+ and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
+ lit_bool_null(),
+ );
+ assert_eq!(expected_expr, result);
}
#[test]
@@ -1579,8 +1603,8 @@ mod tests {
assert_eq!(simplify(lit_bool_null().and(lit(false))), lit(false),);
// c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
- // it can be either NULL or FALSE depending on the value of `c1
BETWEEN Int32(0) AND Int32(10`
- // and should not be rewritten
+ // it can be either NULL or FALSE depending on the value of `c1
BETWEEN Int32(0) AND Int32(10)`
+ // and the Boolean(NULL) should remain
let expr = Expr::Between {
expr: Box::new(col("c1")),
negated: false,
@@ -1588,8 +1612,40 @@ mod tests {
high: Box::new(lit(10)),
};
let expr = expr.and(lit_bool_null());
- let result = simplify(expr.clone());
- assert_eq!(expr, result);
+ let result = simplify(expr);
+
+ let expected_expr = and(
+ and(col("c1").gt_eq(lit(0)), col("c1").lt_eq(lit(10))),
+ lit_bool_null(),
+ );
+ assert_eq!(expected_expr, result);
+ }
+
+ #[test]
+ fn simplify_expr_between() {
+ // c2 between 3 and 4 is c2 >= 3 and c2 <= 4
+ let expr = Expr::Between {
+ expr: Box::new(col("c2")),
+ negated: false,
+ low: Box::new(lit(3)),
+ high: Box::new(lit(4)),
+ };
+ assert_eq!(
+ simplify(expr),
+ and(col("c2").gt_eq(lit(3)), col("c2").lt_eq(lit(4)))
+ );
+
+ // c2 not between 3 and 4 is c2 < 3 or c2 > 4
+ let expr = Expr::Between {
+ expr: Box::new(col("c2")),
+ negated: true,
+ low: Box::new(lit(3)),
+ high: Box::new(lit(4)),
+ };
+ assert_eq!(
+ simplify(expr),
+ or(col("c2").lt(lit(3)), col("c2").gt(lit(4)))
+ );
}
// ------------------------------
@@ -2167,7 +2223,7 @@ mod tests {
.unwrap()
.build()
.unwrap();
- let expected = "Filter: #test.d NOT BETWEEN Int32(1) AND Int32(10) AS
NOT test.d BETWEEN Int32(1) AND Int32(10)\
+ let expected = "Filter: #test.d < Int32(1) OR #test.d > Int32(10) AS
NOT test.d BETWEEN Int32(1) AND Int32(10)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);
@@ -2188,7 +2244,7 @@ mod tests {
.unwrap()
.build()
.unwrap();
- let expected = "Filter: #test.d BETWEEN Int32(1) AND Int32(10) AS NOT
test.d NOT BETWEEN Int32(1) AND Int32(10)\
+ let expected = "Filter: #test.d >= Int32(1) AND #test.d <= Int32(10)
AS NOT test.d NOT BETWEEN Int32(1) AND Int32(10)\
\n TableScan: test";
assert_optimized_plan_eq(&plan, expected);