This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 326117c1c fix: add one more projection to recover output schema (#4733)
326117c1c is described below
commit 326117c1cff16e708e6436795a35ba0cca888704
Author: Ruihang Xia <[email protected]>
AuthorDate: Thu Dec 29 05:02:40 2022 +0800
fix: add one more projection to recover output schema (#4733)
* fix: do not create projection plan manually
Signed-off-by: Ruihang Xia <[email protected]>
* add another projection to change schema back
Signed-off-by: Ruihang Xia <[email protected]>
* conditional recover and add document
Signed-off-by: Ruihang Xia <[email protected]>
* clean up
Signed-off-by: Ruihang Xia <[email protected]>
* check schema after all
Signed-off-by: Ruihang Xia <[email protected]>
* Update datafusion/optimizer/src/common_subexpr_eliminate.rs
Co-authored-by: Andrew Lamb <[email protected]>
* fix format
Signed-off-by: Ruihang Xia <[email protected]>
Signed-off-by: Ruihang Xia <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/tests/sql/predicates.rs | 4 +-
.../optimizer/src/common_subexpr_eliminate.rs | 166 ++++++++++++++-------
2 files changed, 116 insertions(+), 54 deletions(-)
diff --git a/datafusion/core/tests/sql/predicates.rs
b/datafusion/core/tests/sql/predicates.rs
index 94d3e0614..d56f95e55 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -591,8 +591,8 @@ async fn multiple_or_predicates() -> Result<()> {
" Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity
>= Decimal128(Some(100),15,2) AND lineitem.l_quantity <=
Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand =
Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND
lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10)
OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >=
Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decim [...]
" Inner Join: lineitem.l_partkey = part.p_partkey
[l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8,
p_size:Int32]",
" Projection: lineitem.l_partkey, lineitem.l_quantity
[l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
- " Filter: (lineitem.l_quantity >=
Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity AND
lineitem.l_quantity <=
Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR
lineitem.l_quantity >=
Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity AND
lineitem.l_quantity <=
Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity OR
lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Som [...]
- " Projection: lineitem.l_quantity <=
Decimal128(Some(1100),15,2) OR lineitem.l_quantity <=
Decimal128(Some(2000),15,2) AS lineitem.l_quantity <=
Decimal128(Some(1100),15,2) OR lineitem.l_quantity <=
Decimal128(Some(2000),15,2)lineitem.l_quantity <=
Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity
<= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity,
lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR [...]
+ " Filter: (lineitem.l_quantity >=
Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity AND
lineitem.l_quantity <=
Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR
lineitem.l_quantity >=
Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity AND
lineitem.l_quantity <=
Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity OR
lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Som [...]
+ " Projection: lineitem.l_quantity <=
Decimal128(Some(1100),15,2) OR lineitem.l_quantity <=
Decimal128(Some(2000),15,2) AS lineitem.l_quantity <=
Decimal128(Some(1100),15,2) OR lineitem.l_quantity <=
Decimal128(Some(2000),15,2)lineitem.l_quantity <=
Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity
<= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity,
lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR [...]
" TableScan: lineitem projection=[l_partkey, l_quantity],
partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) OR
lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >=
Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2)
OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <=
Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2)
OR lineitem.l_quantity <= De [...]
" Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <=
Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR
part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size
>= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size],
partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND
part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <=
Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)]
[p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs
b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index a8c9f5d86..c8bddcfbf 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -86,7 +86,7 @@ impl CommonSubexprEliminate {
.try_optimize(input, config)?
.unwrap_or_else(|| input.clone());
if !affected_id.is_empty() {
- new_input = build_project_plan(new_input, affected_id, expr_set)?;
+ new_input = build_common_expr_project_plan(new_input, affected_id,
expr_set)?;
}
Ok((rewrite_exprs, new_input))
@@ -101,7 +101,8 @@ impl OptimizerRule for CommonSubexprEliminate {
) -> Result<Option<LogicalPlan>> {
let mut expr_set = ExprSet::new();
- match plan {
+ let original_schema = plan.schema().clone();
+ let mut optimized_plan = match plan {
LogicalPlan::Projection(Projection {
expr,
input,
@@ -114,13 +115,11 @@ impl OptimizerRule for CommonSubexprEliminate {
let (mut new_expr, new_input) =
self.rewrite_expr(&[expr], &[&arrays], input, &mut
expr_set, config)?;
- Ok(Some(LogicalPlan::Projection(
- Projection::try_new_with_schema(
- pop_expr(&mut new_expr)?,
- Arc::new(new_input),
- schema.clone(),
- )?,
- )))
+ LogicalPlan::Projection(Projection::try_new_with_schema(
+ pop_expr(&mut new_expr)?,
+ Arc::new(new_input),
+ schema.clone(),
+ )?)
}
LogicalPlan::Filter(filter) => {
let input = &filter.input;
@@ -143,14 +142,11 @@ impl OptimizerRule for CommonSubexprEliminate {
)?;
if let Some(predicate) = pop_expr(&mut new_expr)?.pop() {
- Ok(Some(LogicalPlan::Filter(Filter::try_new(
- predicate,
- Arc::new(new_input),
- )?)))
+ LogicalPlan::Filter(Filter::try_new(predicate,
Arc::new(new_input))?)
} else {
- Err(DataFusionError::Internal(
+ return Err(DataFusionError::Internal(
"Failed to pop predicate expr".to_string(),
- ))
+ ));
}
}
LogicalPlan::Window(Window {
@@ -169,11 +165,11 @@ impl OptimizerRule for CommonSubexprEliminate {
config,
)?;
- Ok(Some(LogicalPlan::Window(Window {
+ LogicalPlan::Window(Window {
input: Arc::new(new_input),
window_expr: pop_expr(&mut new_expr)?,
schema: schema.clone(),
- })))
+ })
}
LogicalPlan::Aggregate(Aggregate {
group_expr,
@@ -198,14 +194,12 @@ impl OptimizerRule for CommonSubexprEliminate {
let new_aggr_expr = pop_expr(&mut new_expr)?;
let new_group_expr = pop_expr(&mut new_expr)?;
- Ok(Some(LogicalPlan::Aggregate(
- Aggregate::try_new_with_schema(
- Arc::new(new_input),
- new_group_expr,
- new_aggr_expr,
- schema.clone(),
- )?,
- )))
+ LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
+ Arc::new(new_input),
+ new_group_expr,
+ new_aggr_expr,
+ schema.clone(),
+ )?)
}
LogicalPlan::Sort(Sort { expr, input, fetch }) => {
let input_schema = Arc::clone(input.schema());
@@ -214,11 +208,11 @@ impl OptimizerRule for CommonSubexprEliminate {
let (mut new_expr, new_input) =
self.rewrite_expr(&[expr], &[&arrays], input, &mut
expr_set, config)?;
- Ok(Some(LogicalPlan::Sort(Sort {
+ LogicalPlan::Sort(Sort {
expr: pop_expr(&mut new_expr)?,
input: Arc::new(new_input),
fetch: *fetch,
- })))
+ })
}
LogicalPlan::Join(_)
| LogicalPlan::CrossJoin(_)
@@ -244,9 +238,16 @@ impl OptimizerRule for CommonSubexprEliminate {
| LogicalPlan::Extension(_)
| LogicalPlan::Prepare(_) => {
// apply the optimization to all inputs of the plan
- Ok(Some(utils::optimize_children(self, plan, config)?))
+ utils::optimize_children(self, plan, config)?
}
+ };
+
+ // add an additional projection if the output schema changed.
+ if optimized_plan.schema() != &original_schema {
+ optimized_plan = build_recover_project_plan(&original_schema,
optimized_plan);
}
+
+ Ok(Some(optimized_plan))
}
fn name(&self) -> &str {
@@ -289,13 +290,12 @@ fn to_arrays(
}
/// Build the "intermediate" projection plan that evaluates the extracted
common expressions.
-fn build_project_plan(
+fn build_common_expr_project_plan(
input: LogicalPlan,
affected_id: BTreeSet<Identifier>,
expr_set: &ExprSet,
) -> Result<LogicalPlan> {
let mut project_exprs = vec![];
- let mut fields = vec![];
let mut fields_set = BTreeSet::new();
for id in affected_id {
@@ -304,7 +304,6 @@ fn build_project_plan(
// todo: check `nullable`
let field = DFField::new(None, &id, data_type.clone(), true);
fields_set.insert(field.name().to_owned());
- fields.push(field);
project_exprs.push(expr.clone().alias(&id));
}
_ => {
@@ -317,20 +316,32 @@ fn build_project_plan(
for field in input.schema().fields() {
if fields_set.insert(field.qualified_name()) {
- fields.push(field.clone());
project_exprs.push(Expr::Column(field.qualified_column()));
}
}
- let schema = DFSchema::new_with_metadata(fields, HashMap::new())?;
-
- Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
+ Ok(LogicalPlan::Projection(Projection::try_new(
project_exprs,
Arc::new(input),
- Arc::new(schema),
)?))
}
+/// Build the projection plan to eliminate unexpected columns produced by
+/// the "intermediate" projection plan built in
[build_common_expr_project_plan].
+///
+/// This is for plans that don't keep their own output schema, like `Filter`
+/// or `Sort`.
+fn build_recover_project_plan(schema: &DFSchema, input: LogicalPlan) ->
LogicalPlan {
+ let col_exprs = schema
+ .fields()
+ .iter()
+ .map(|field| Expr::Column(field.qualified_column()))
+ .collect();
+ LogicalPlan::Projection(
+ Projection::try_new(col_exprs, Arc::new(input))
+ .expect("Cannot build projection plan from an invalid schema"),
+ )
+}
+
/// Go through an expression tree and generate identifier.
///
/// An identifier contains information of the expression itself and its
sub-expression.
@@ -567,6 +578,7 @@ mod test {
use arrow::datatypes::{Field, Schema};
+ use datafusion_common::DFSchema;
use datafusion_expr::logical_plan::{table_scan, JoinType};
use datafusion_expr::{
avg, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder,
sum,
@@ -754,7 +766,6 @@ mod test {
\n TableScan: test";
assert_optimized_plan_eq(expected, &plan);
-
Ok(())
}
@@ -762,16 +773,30 @@ mod test {
fn redundant_project_fields() {
let table_scan = test_table_scan().unwrap();
let affected_id: BTreeSet<Identifier> =
- ["c+a".to_string(), "d+a".to_string()].into_iter().collect();
- let expr_set = [
+ ["c+a".to_string(), "b+a".to_string()].into_iter().collect();
+ let expr_set_1 = [
+ (
+ "c+a".to_string(),
+ (col("c") + col("a"), 1, DataType::UInt32),
+ ),
+ (
+ "b+a".to_string(),
+ (col("b") + col("a"), 1, DataType::UInt32),
+ ),
+ ]
+ .into_iter()
+ .collect();
+ let expr_set_2 = [
("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
- ("d+a".to_string(), (col("d+a"), 1, DataType::UInt32)),
+ ("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)),
]
.into_iter()
.collect();
let project =
- build_project_plan(table_scan, affected_id.clone(),
&expr_set).unwrap();
- let project_2 = build_project_plan(project, affected_id,
&expr_set).unwrap();
+ build_common_expr_project_plan(table_scan, affected_id.clone(),
&expr_set_1)
+ .unwrap();
+ let project_2 =
+ build_common_expr_project_plan(project, affected_id,
&expr_set_2).unwrap();
let mut field_set = BTreeSet::new();
for field in project_2.schema().fields() {
@@ -789,15 +814,38 @@ mod test {
.build()
.unwrap();
let affected_id: BTreeSet<Identifier> =
- ["c+a".to_string(), "d+a".to_string()].into_iter().collect();
- let expr_set = [
- ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
- ("d+a".to_string(), (col("d+a"), 1, DataType::UInt32)),
+ ["test1.c+test1.a".to_string(), "test1.b+test1.a".to_string()]
+ .into_iter()
+ .collect();
+ let expr_set_1 = [
+ (
+ "test1.c+test1.a".to_string(),
+ (col("test1.c") + col("test1.a"), 1, DataType::UInt32),
+ ),
+ (
+ "test1.b+test1.a".to_string(),
+ (col("test1.b") + col("test1.a"), 1, DataType::UInt32),
+ ),
+ ]
+ .into_iter()
+ .collect();
+ let expr_set_2 = [
+ (
+ "test1.c+test1.a".to_string(),
+ (col("test1.c+test1.a"), 1, DataType::UInt32),
+ ),
+ (
+ "test1.b+test1.a".to_string(),
+ (col("test1.b+test1.a"), 1, DataType::UInt32),
+ ),
]
.into_iter()
.collect();
- let project = build_project_plan(join, affected_id.clone(),
&expr_set).unwrap();
- let project_2 = build_project_plan(project, affected_id,
&expr_set).unwrap();
+ let project =
+ build_common_expr_project_plan(join, affected_id.clone(),
&expr_set_1)
+ .unwrap();
+ let project_2 =
+ build_common_expr_project_plan(project, affected_id,
&expr_set_2).unwrap();
let mut field_set = BTreeSet::new();
for field in project_2.schema().fields() {
@@ -839,10 +887,6 @@ mod test {
.collect();
let formatted_fields_with_datatype =
format!("{fields_with_datatypes:#?}");
let expected = r###"[
- (
- "CAST(table.a AS Int64)table.a",
- Int64,
- ),
(
"a",
UInt64,
@@ -858,4 +902,22 @@ mod test {
]"###;
assert_eq!(expected, formatted_fields_with_datatype);
}
+
+ #[test]
+ fn filter_schema_changed() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(lit(1).gt(col("a")).and(lit(1).gt(col("a"))))?
+ .build()?;
+
+ let expected = "Projection: test.a, test.b, test.c\
+ \n Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND
Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\
+ \n Projection: Int32(1) > test.a AS Int32(1) >
test.atest.aInt32(1), test.a, test.b, test.c\
+ \n TableScan: test";
+
+ assert_optimized_plan_eq(expected, &plan);
+
+ Ok(())
+ }
}