This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 57e445aad Revert "Factorize common AND factors out of OR predicates to
support filterPu… (#3859)" (#3897)
57e445aad is described below
commit 57e445aadcc87cad33de8a969eb4203b219ec9dd
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Oct 19 16:19:23 2022 -0400
Revert "Factorize common AND factors out of OR predicates to support
filterPu… (#3859)" (#3897)
This reverts commit ddfd052ceef484dd36b5803dee5530a24e504893.
---
benchmarks/expected-plans/q7.txt | 2 +-
.../src/physical_plan/file_format/row_filter.rs | 4 +-
datafusion/core/tests/sql/joins.rs | 7 +-
datafusion/optimizer/src/filter_push_down.rs | 37 +--
datafusion/optimizer/src/utils.rs | 270 ++-------------------
5 files changed, 25 insertions(+), 295 deletions(-)
diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt
index 73fe8574a..a1d1806f9 100644
--- a/benchmarks/expected-plans/q7.txt
+++ b/benchmarks/expected-plans/q7.txt
@@ -3,7 +3,7 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation
ASC NULLS LAST,
Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation,
shipping.l_year]], aggr=[[SUM(shipping.volume)]]
Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year,
shipping.volume, alias=shipping
Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation,
datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year,
CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) *
CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23,
2)) AS Decimal128(38, 4)) AS volume, alias=shipping
- Filter: n1.n_name = Utf8("FRANCE") OR n2.n_name = Utf8("FRANCE") AND
n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY")
+ Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY")
OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE")
Inner Join: customer.c_nationkey = n2.n_nationkey
Inner Join: supplier.s_nationkey = n1.n_nationkey
Inner Join: orders.o_custkey = customer.c_custkey
diff --git a/datafusion/core/src/physical_plan/file_format/row_filter.rs
b/datafusion/core/src/physical_plan/file_format/row_filter.rs
index 2ac55d368..dd9c8fb65 100644
--- a/datafusion/core/src/physical_plan/file_format/row_filter.rs
+++ b/datafusion/core/src/physical_plan/file_format/row_filter.rs
@@ -22,7 +22,7 @@ use arrow::record_batch::RecordBatch;
use datafusion_common::{Column, DataFusionError, Result, ScalarValue,
ToDFSchema};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter,
RewriteRecursion};
-use datafusion_expr::{Expr, Operator};
+use datafusion_expr::Expr;
use datafusion_optimizer::utils::split_conjunction_owned;
use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
@@ -253,7 +253,7 @@ pub fn build_row_filter(
metadata: &ParquetMetaData,
reorder_predicates: bool,
) -> Result<Option<RowFilter>> {
- let predicates = split_conjunction_owned(expr, Operator::And);
+ let predicates = split_conjunction_owned(expr);
let mut candidates: Vec<FilterCandidate> = predicates
.into_iter()
diff --git a/datafusion/core/tests/sql/joins.rs
b/datafusion/core/tests/sql/joins.rs
index 1ba8cf7ac..2ff4947b3 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1468,15 +1468,10 @@ async fn reduce_left_join_2() -> Result<()> {
.expect(&msg);
let state = ctx.state();
let plan = state.optimize(&plan)?;
-
- // filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')`
- // could be write to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name
!= 'w' or t2.t2_int < 10)`
- // the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push
down left join side and remove in filter.
-
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name,
t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N,
t2_name:Utf8;N, t2_int:UInt32;N]",
- " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS
Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS
Int64) > Int64(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
" Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name !=
Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
diff --git a/datafusion/optimizer/src/filter_push_down.rs
b/datafusion/optimizer/src/filter_push_down.rs
index a768b0a7f..6396f1fbf 100644
--- a/datafusion/optimizer/src/filter_push_down.rs
+++ b/datafusion/optimizer/src/filter_push_down.rs
@@ -14,7 +14,6 @@
//! Filter Push Down optimizer rule ensures that filters are applied as early
as possible in the plan
-use crate::utils::{split_conjunction, CnfHelper};
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, DataFusionError, Result};
use datafusion_expr::{
@@ -29,7 +28,6 @@ use datafusion_expr::{
utils::{expr_to_columns, exprlist_to_columns, from_plan},
Expr, Operator, TableProviderFilterPushDown,
};
-use log::error;
use std::collections::{HashMap, HashSet};
use std::iter::once;
@@ -532,14 +530,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) ->
Result<LogicalPlan> {
}
LogicalPlan::Analyze { .. } => push_down(&state, plan),
LogicalPlan::Filter(filter) => {
- let filter_cnf = filter.predicate().clone().rewrite(&mut
CnfHelper::new());
- let predicates = match filter_cnf {
- Ok(ref expr) => split_conjunction(expr),
- Err(e) => {
- error!("Fail at CnfHelper rewrite: {}.", e);
- split_conjunction(filter.predicate())
- }
- };
+ let predicates = utils::split_conjunction(filter.predicate());
predicates
.into_iter()
@@ -962,30 +953,6 @@ mod tests {
Ok(())
}
- #[test]
- fn filter_keep_partial_agg() -> Result<()> {
- let table_scan = test_table_scan()?;
- let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64)));
- let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64)));
- let filter = f1.or(f2);
- let plan = LogicalPlanBuilder::from(table_scan)
- .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
- .filter(filter)?
- .build()?;
- // filter of aggregate is after aggregation since they are
non-commutative
- // (c =1 AND b > 2) OR (c = 1 AND b > 3)
- // rewrite to CNF
- // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR
C = 1) AND (b > 2 OR b > 3)
-
- let expected = "\
- Filter: test.c = Int64(1) OR b > Int64(3) AND b > Int64(2) OR
test.c = Int64(1) AND b > Int64(2) OR b > Int64(3)\
- \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
- \n Filter: test.c = Int64(1) OR test.c = Int64(1)\
- \n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
- }
-
/// verifies that a filter is pushed to before a projection, the filter
expression is correctly re-written
#[test]
fn alias() -> Result<()> {
@@ -2377,7 +2344,7 @@ mod tests {
.filter(filter)?
.build()?;
- let expected = "Filter: test.a = d OR test.b = e AND test.a = d OR
test.c < UInt32(10) AND test.b > UInt32(1) OR test.b = e\
+ let expected = "Filter: test.a = d AND test.b > UInt32(1) OR test.b =
e AND test.c < UInt32(10)\
\n CrossJoin:\
\n Projection: test.a, test.b, test.c\
\n Filter: test.b > UInt32(1) OR test.c <
UInt32(10)\
diff --git a/datafusion/optimizer/src/utils.rs
b/datafusion/optimizer/src/utils.rs
index f088085b8..130df3e0e 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -21,7 +21,7 @@ use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_common::{plan_err, Column, DFSchemaRef};
use datafusion_expr::expr::BinaryExpr;
-use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter,
RewriteRecursion};
+use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor,
Recursion};
use datafusion_expr::{
and, col,
@@ -84,7 +84,7 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs:
Vec<&'a Expr>) -> Vec<&
///
/// # Example
/// ```
-/// # use datafusion_expr::{col, lit, Operator};
+/// # use datafusion_expr::{col, lit};
/// # use datafusion_optimizer::utils::split_conjunction_owned;
/// // a=1 AND b=2
/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
@@ -96,23 +96,23 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs:
Vec<&'a Expr>) -> Vec<&
/// ];
///
/// // use split_conjunction_owned to split them
-/// assert_eq!(split_conjunction_owned(expr, Operator::And), split);
+/// assert_eq!(split_conjunction_owned(expr), split);
/// ```
-pub fn split_conjunction_owned(expr: Expr, op: Operator) -> Vec<Expr> {
- split_conjunction_owned_impl(expr, op, vec![])
+pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
+ split_conjunction_owned_impl(expr, vec![])
}
-fn split_conjunction_owned_impl(
- expr: Expr,
- operator: Operator,
- mut exprs: Vec<Expr>,
-) -> Vec<Expr> {
+fn split_conjunction_owned_impl(expr: Expr, mut exprs: Vec<Expr>) -> Vec<Expr>
{
match expr {
- Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
- let exprs = split_conjunction_owned_impl(*left, Operator::And,
exprs);
- split_conjunction_owned_impl(*right, Operator::And, exprs)
+ Expr::BinaryExpr(BinaryExpr {
+ right,
+ op: Operator::And,
+ left,
+ }) => {
+ let exprs = split_conjunction_owned_impl(*left, exprs);
+ split_conjunction_owned_impl(*right, exprs)
}
- Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr,
Operator::And, exprs),
+ Expr::Alias(expr, _) => split_conjunction_owned_impl(*expr, exprs),
other => {
exprs.push(other);
exprs
@@ -120,149 +120,6 @@ fn split_conjunction_owned_impl(
}
}
-/// Converts an expression to conjunctive normal form (CNF).
-///
-/// The following expression is in CNF:
-/// `(a OR b) AND (c OR d)`
-/// The following is not in CNF:
-/// `(a AND b) OR c`.
-/// But could be rewrite to a CNF expression:
-/// `(a OR c) AND (b OR c)`.
-///
-/// # Example
-/// ```
-/// # use datafusion_expr::{col, lit};
-/// # use datafusion_expr::expr_rewriter::ExprRewritable;
-/// # use datafusion_optimizer::utils::CnfHelper;
-/// // (a=1 AND b=2)OR c = 3
-/// let expr1 = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
-/// let expr2 = col("c").eq(lit(3));
-/// let expr = expr1.or(expr2);
-///
-/// //(a=1 or c=3)AND(b=2 or c=3)
-/// let expr1 = col("a").eq(lit(1)).or(col("c").eq(lit(3)));
-/// let expr2 = col("b").eq(lit(2)).or(col("c").eq(lit(3)));
-/// let expect = expr1.and(expr2);
-/// // use split_conjunction_owned to split them
-/// assert_eq!(expr.rewrite(& mut CnfHelper::new()).unwrap(), expect);
-/// ```
-///
-pub struct CnfHelper {
- max_count: usize,
- current_count: usize,
- exprs: Vec<Expr>,
- original_expr: Option<Expr>,
-}
-
-impl CnfHelper {
- pub fn new() -> Self {
- CnfHelper {
- max_count: 50,
- current_count: 0,
- exprs: vec![],
- original_expr: None,
- }
- }
-
- pub fn new_with_max_count(max_count: usize) -> Self {
- CnfHelper {
- max_count,
- current_count: 0,
- exprs: vec![],
- original_expr: None,
- }
- }
-
- fn increment_and_check_overload(&mut self) -> bool {
- self.current_count += 1;
- self.current_count >= self.max_count
- }
-}
-
-impl ExprRewriter for CnfHelper {
- fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
- let is_root = self.original_expr.is_none();
- if is_root {
- self.original_expr = Some(expr.clone());
- }
- match expr {
- Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
- match op {
- Operator::And => {
- if self.increment_and_check_overload() {
- return Ok(RewriteRecursion::Mutate);
- }
- }
- // (a AND b) OR (c AND d) = (a OR b) AND (a OR c) AND (b
OR c) AND (b OR d)
- Operator::Or => {
- let left_and_split =
- split_conjunction_owned(*left.clone(),
Operator::And);
- let right_and_split =
- split_conjunction_owned(*right.clone(),
Operator::And);
- // Avoid create to much Expr like in tpch q19.
- let lc = split_conjunction_owned(*left.clone(),
Operator::Or)
- .into_iter()
- .flat_map(|e| split_conjunction_owned(e,
Operator::And))
- .count();
- let rc = split_conjunction_owned(*right.clone(),
Operator::Or)
- .into_iter()
- .flat_map(|e| split_conjunction_owned(e,
Operator::And))
- .count();
- self.current_count += lc * rc - 1;
- if self.increment_and_check_overload() {
- return Ok(RewriteRecursion::Mutate);
- }
- left_and_split.iter().for_each(|l| {
- right_and_split.iter().for_each(|r| {
- self.exprs.push(Expr::BinaryExpr(BinaryExpr {
- left: Box::new(l.clone()),
- op: Operator::Or,
- right: Box::new(r.clone()),
- }))
- })
- });
- return Ok(RewriteRecursion::Mutate);
- }
- _ => {
- if self.increment_and_check_overload() {
- return Ok(RewriteRecursion::Mutate);
- }
- self.exprs.push(expr.clone());
- return Ok(RewriteRecursion::Stop);
- }
- }
- }
- other => {
- if self.increment_and_check_overload() {
- return Ok(RewriteRecursion::Mutate);
- }
- self.exprs.push(other.clone());
- return Ok(RewriteRecursion::Stop);
- }
- }
- if is_root {
- Ok(RewriteRecursion::Continue)
- } else {
- Ok(RewriteRecursion::Skip)
- }
- }
-
- fn mutate(&mut self, _expr: Expr) -> Result<Expr> {
- if self.current_count >= self.max_count {
- Ok(self.original_expr.as_ref().unwrap().clone())
- } else {
- Ok(conjunction(self.exprs.clone())
- .unwrap_or_else(||
self.original_expr.as_ref().unwrap().clone()))
- }
- }
-}
-
-impl Default for CnfHelper {
- fn default() -> Self {
- Self::new()
- }
-}
-
/// Combines an array of filter expressions into a single filter
/// expression consisting of the input filter expressions joined with
/// logical AND.
@@ -612,7 +469,7 @@ mod tests {
use super::*;
use arrow::datatypes::DataType;
use datafusion_common::Column;
- use datafusion_expr::{col, lit, or, utils::expr_to_columns};
+ use datafusion_expr::{col, lit, utils::expr_to_columns};
use std::collections::HashSet;
use std::ops::Add;
@@ -653,16 +510,13 @@ mod tests {
#[test]
fn test_split_conjunction_owned() {
let expr = col("a");
- assert_eq!(
- split_conjunction_owned(expr.clone(), Operator::And),
- vec![expr]
- );
+ assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
}
#[test]
fn test_split_conjunction_owned_two() {
assert_eq!(
- split_conjunction_owned(col("a").eq(lit(5)).and(col("b")),
Operator::And),
+ split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))),
vec![col("a").eq(lit(5)), col("b")]
);
}
@@ -670,10 +524,7 @@ mod tests {
#[test]
fn test_split_conjunction_owned_alias() {
assert_eq!(
- split_conjunction_owned(
- col("a").eq(lit(5)).and(col("b").alias("the_alias")),
- Operator::And
- ),
+
split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))),
vec![
col("a").eq(lit(5)),
// no alias on b
@@ -719,10 +570,7 @@ mod tests {
#[test]
fn test_split_conjunction_owned_or() {
let expr = col("a").eq(lit(5)).or(col("b"));
- assert_eq!(
- split_conjunction_owned(expr.clone(), Operator::And),
- vec![expr]
- );
+ assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
}
#[test]
@@ -815,84 +663,4 @@ mod tests {
"mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
)
}
-
- #[test]
- fn test_rewrite_cnf() {
- let a_1 = col("a").eq(lit(1i64));
- let a_2 = col("a").eq(lit(2i64));
-
- let b_1 = col("b").eq(lit(1i64));
- let b_2 = col("b").eq(lit(2i64));
-
- // Test rewrite on a1_and_b2 and a2_and_b1 -> not change
- let mut helper = CnfHelper::new();
- let expr1 = and(a_1.clone(), b_2.clone());
- let expect = expr1.clone();
- let res = expr1.rewrite(&mut helper).unwrap();
- assert_eq!(expect, res);
-
- // Test rewrite on a1_and_b2 and a2_and_b1 -> (((a1 and b2) and a2)
and b1)
- let mut helper = CnfHelper::new();
- let expr1 = and(and(a_1.clone(), b_2.clone()), and(a_2.clone(),
b_1.clone()));
- let expect = and(a_1.clone(), b_2.clone())
- .and(a_2.clone())
- .and(b_1.clone());
- let res = expr1.rewrite(&mut helper).unwrap();
- assert_eq!(expect, res);
-
- // Test rewrite on a1_or_b2 -> not change
- let mut helper = CnfHelper::new();
- let expr1 = or(a_1.clone(), b_2.clone());
- let expect = expr1.clone();
- let res = expr1.rewrite(&mut helper).unwrap();
- assert_eq!(expect, res);
-
- // Test rewrite on a1_and_b2 or a2_and_b1 -> a1_or_a2 and a1_or_b1
and b2_or_a2 and b2_or_b1
- let mut helper = CnfHelper::new();
- let expr1 = or(and(a_1.clone(), b_2.clone()), and(a_2.clone(),
b_1.clone()));
- let a1_or_a2 = or(a_1.clone(), a_2.clone());
- let a1_or_b1 = or(a_1.clone(), b_1.clone());
- let b2_or_a2 = or(b_2.clone(), a_2.clone());
- let b2_or_b1 = or(b_2.clone(), b_1.clone());
- let expect = and(a1_or_a2, a1_or_b1).and(b2_or_a2).and(b2_or_b1);
- let res = expr1.rewrite(&mut helper).unwrap();
- assert_eq!(expect, res);
-
- // Test rewrite on a1_or_b2 or a2_and_b1 -> ( a1_or_a2 or a2 ) and
(a1_or_a2 or b1)
- let mut helper = CnfHelper::new();
- let a1_or_b2 = or(a_1.clone(), b_2.clone());
- let expr1 = or(or(a_1.clone(), b_2.clone()), and(a_2.clone(),
b_1.clone()));
- let expect = or(a1_or_b2.clone(), a_2.clone()).and(or(a1_or_b2,
b_1.clone()));
- let res = expr1.rewrite(&mut helper).unwrap();
- assert_eq!(expect, res);
-
- // Test rewrite on a1_or_b2 or a2_or_b1 -> not change
- let mut helper = CnfHelper::new();
- let expr1 = or(or(a_1, b_2), or(a_2, b_1));
- let expect = expr1.clone();
- let res = expr1.rewrite(&mut helper).unwrap();
- assert_eq!(expect, res);
- }
-
- #[test]
- fn test_rewrite_cnf_overflow() {
- // in this situation:
- // AND = (a=1 and b=2)
- // rewrite (AND * 10) or (AND * 10), it will produce 10 * 10 = 100
(a=1 or b=2)
- // which cause size expansion.
-
- let mut expr1 = col("test1").eq(lit(1i64));
- let expr2 = col("test2").eq(lit(2i64));
-
- for _i in 0..9 {
- expr1 = expr1.clone().and(expr2.clone());
- }
- let expr3 = expr1.clone();
- let expr = or(expr1, expr3);
- let mut helper = CnfHelper::new();
- let res = expr.clone().rewrite(&mut helper).unwrap();
- assert_eq!(100, helper.current_count);
- assert_eq!(res, expr);
- assert!(helper.current_count >= helper.max_count);
- }
}