This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 326117c1c fix: add one more projection to recover output schema (#4733)
326117c1c is described below

commit 326117c1cff16e708e6436795a35ba0cca888704
Author: Ruihang Xia <[email protected]>
AuthorDate: Thu Dec 29 05:02:40 2022 +0800

    fix: add one more projection to recover output schema (#4733)
    
    * fix: do not create projection plan manually
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    * add another projection to change schema back
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    * conditional recover and add document
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    * clean up
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    * check schema after all
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    * Update datafusion/optimizer/src/common_subexpr_eliminate.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * fix format
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    Signed-off-by: Ruihang Xia <[email protected]>
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/tests/sql/predicates.rs            |   4 +-
 .../optimizer/src/common_subexpr_eliminate.rs      | 166 ++++++++++++++-------
 2 files changed, 116 insertions(+), 54 deletions(-)

diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs
index 94d3e0614..d56f95e55 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -591,8 +591,8 @@ async fn multiple_or_predicates() -> Result<()> {
         "    Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity 
>= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = 
Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND 
lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) 
OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= 
Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decim [...]
         "      Inner Join: lineitem.l_partkey = part.p_partkey 
[l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, 
p_size:Int32]",
         "        Projection: lineitem.l_partkey, lineitem.l_quantity 
[l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
-        "          Filter: (lineitem.l_quantity >= 
Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity AND 
lineitem.l_quantity <= 
Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR 
lineitem.l_quantity >= 
Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity AND 
lineitem.l_quantity <= 
Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity OR 
lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Som [...]
-        "            Projection: lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= 
Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= 
Decimal128(Some(2000),15,2)lineitem.l_quantity <= 
Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity
 <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, 
lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR [...]
+        "          Filter: (lineitem.l_quantity >= 
Decimal128(Some(100),15,2)Decimal128(Some(100),15,2)lineitem.l_quantity AND 
lineitem.l_quantity <= 
Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity OR 
lineitem.l_quantity >= 
Decimal128(Some(1000),15,2)Decimal128(Some(1000),15,2)lineitem.l_quantity AND 
lineitem.l_quantity <= 
Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantity OR 
lineitem.l_quantity >= Decimal128(Some(2000),15,2)Decimal128(Som [...]
+        "            Projection: lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= 
Decimal128(Some(2000),15,2) AS lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) OR lineitem.l_quantity <= 
Decimal128(Some(2000),15,2)lineitem.l_quantity <= 
Decimal128(Some(2000),15,2)Decimal128(Some(2000),15,2)lineitem.l_quantitylineitem.l_quantity
 <= Decimal128(Some(1100),15,2)Decimal128(Some(1100),15,2)lineitem.l_quantity, 
lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR [...]
         "              TableScan: lineitem projection=[l_partkey, l_quantity], 
partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) OR 
lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity >= 
Decimal128(Some(2000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) 
OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) OR lineitem.l_quantity <= 
Decimal128(Some(3000),15,2), lineitem.l_quantity >= Decimal128(Some(100),15,2) 
OR lineitem.l_quantity <= De [...]
         "        Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= 
Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR 
part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size 
>= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
         "          TableScan: part projection=[p_partkey, p_brand, p_size], 
partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND 
part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= 
Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)] 
[p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index a8c9f5d86..c8bddcfbf 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -86,7 +86,7 @@ impl CommonSubexprEliminate {
             .try_optimize(input, config)?
             .unwrap_or_else(|| input.clone());
         if !affected_id.is_empty() {
-            new_input = build_project_plan(new_input, affected_id, expr_set)?;
+            new_input = build_common_expr_project_plan(new_input, affected_id, 
expr_set)?;
         }
 
         Ok((rewrite_exprs, new_input))
@@ -101,7 +101,8 @@ impl OptimizerRule for CommonSubexprEliminate {
     ) -> Result<Option<LogicalPlan>> {
         let mut expr_set = ExprSet::new();
 
-        match plan {
+        let original_schema = plan.schema().clone();
+        let mut optimized_plan = match plan {
             LogicalPlan::Projection(Projection {
                 expr,
                 input,
@@ -114,13 +115,11 @@ impl OptimizerRule for CommonSubexprEliminate {
                 let (mut new_expr, new_input) =
                     self.rewrite_expr(&[expr], &[&arrays], input, &mut 
expr_set, config)?;
 
-                Ok(Some(LogicalPlan::Projection(
-                    Projection::try_new_with_schema(
-                        pop_expr(&mut new_expr)?,
-                        Arc::new(new_input),
-                        schema.clone(),
-                    )?,
-                )))
+                LogicalPlan::Projection(Projection::try_new_with_schema(
+                    pop_expr(&mut new_expr)?,
+                    Arc::new(new_input),
+                    schema.clone(),
+                )?)
             }
             LogicalPlan::Filter(filter) => {
                 let input = &filter.input;
@@ -143,14 +142,11 @@ impl OptimizerRule for CommonSubexprEliminate {
                 )?;
 
                 if let Some(predicate) = pop_expr(&mut new_expr)?.pop() {
-                    Ok(Some(LogicalPlan::Filter(Filter::try_new(
-                        predicate,
-                        Arc::new(new_input),
-                    )?)))
+                    LogicalPlan::Filter(Filter::try_new(predicate, 
Arc::new(new_input))?)
                 } else {
-                    Err(DataFusionError::Internal(
+                    return Err(DataFusionError::Internal(
                         "Failed to pop predicate expr".to_string(),
-                    ))
+                    ));
                 }
             }
             LogicalPlan::Window(Window {
@@ -169,11 +165,11 @@ impl OptimizerRule for CommonSubexprEliminate {
                     config,
                 )?;
 
-                Ok(Some(LogicalPlan::Window(Window {
+                LogicalPlan::Window(Window {
                     input: Arc::new(new_input),
                     window_expr: pop_expr(&mut new_expr)?,
                     schema: schema.clone(),
-                })))
+                })
             }
             LogicalPlan::Aggregate(Aggregate {
                 group_expr,
@@ -198,14 +194,12 @@ impl OptimizerRule for CommonSubexprEliminate {
                 let new_aggr_expr = pop_expr(&mut new_expr)?;
                 let new_group_expr = pop_expr(&mut new_expr)?;
 
-                Ok(Some(LogicalPlan::Aggregate(
-                    Aggregate::try_new_with_schema(
-                        Arc::new(new_input),
-                        new_group_expr,
-                        new_aggr_expr,
-                        schema.clone(),
-                    )?,
-                )))
+                LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
+                    Arc::new(new_input),
+                    new_group_expr,
+                    new_aggr_expr,
+                    schema.clone(),
+                )?)
             }
             LogicalPlan::Sort(Sort { expr, input, fetch }) => {
                 let input_schema = Arc::clone(input.schema());
@@ -214,11 +208,11 @@ impl OptimizerRule for CommonSubexprEliminate {
                 let (mut new_expr, new_input) =
                     self.rewrite_expr(&[expr], &[&arrays], input, &mut 
expr_set, config)?;
 
-                Ok(Some(LogicalPlan::Sort(Sort {
+                LogicalPlan::Sort(Sort {
                     expr: pop_expr(&mut new_expr)?,
                     input: Arc::new(new_input),
                     fetch: *fetch,
-                })))
+                })
             }
             LogicalPlan::Join(_)
             | LogicalPlan::CrossJoin(_)
@@ -244,9 +238,16 @@ impl OptimizerRule for CommonSubexprEliminate {
             | LogicalPlan::Extension(_)
             | LogicalPlan::Prepare(_) => {
                 // apply the optimization to all inputs of the plan
-                Ok(Some(utils::optimize_children(self, plan, config)?))
+                utils::optimize_children(self, plan, config)?
             }
+        };
+
+        // add an additional projection if the output schema changed.
+        if optimized_plan.schema() != &original_schema {
+            optimized_plan = build_recover_project_plan(&original_schema, 
optimized_plan);
         }
+
+        Ok(Some(optimized_plan))
     }
 
     fn name(&self) -> &str {
@@ -289,13 +290,12 @@ fn to_arrays(
 }
 
 /// Build the "intermediate" projection plan that evaluates the extracted 
common expressions.
-fn build_project_plan(
+fn build_common_expr_project_plan(
     input: LogicalPlan,
     affected_id: BTreeSet<Identifier>,
     expr_set: &ExprSet,
 ) -> Result<LogicalPlan> {
     let mut project_exprs = vec![];
-    let mut fields = vec![];
     let mut fields_set = BTreeSet::new();
 
     for id in affected_id {
@@ -304,7 +304,6 @@ fn build_project_plan(
                 // todo: check `nullable`
                 let field = DFField::new(None, &id, data_type.clone(), true);
                 fields_set.insert(field.name().to_owned());
-                fields.push(field);
                 project_exprs.push(expr.clone().alias(&id));
             }
             _ => {
@@ -317,20 +316,32 @@ fn build_project_plan(
 
     for field in input.schema().fields() {
         if fields_set.insert(field.qualified_name()) {
-            fields.push(field.clone());
             project_exprs.push(Expr::Column(field.qualified_column()));
         }
     }
 
-    let schema = DFSchema::new_with_metadata(fields, HashMap::new())?;
-
-    Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
+    Ok(LogicalPlan::Projection(Projection::try_new(
         project_exprs,
         Arc::new(input),
-        Arc::new(schema),
     )?))
 }
 
+/// Project `input` onto exactly the columns of `schema`, dropping the extra columns added by
+/// the "intermediate" projection built in [`build_common_expr_project_plan`].
+///
+/// Needed for plans like `Filter`/`Sort` that don't keep their own output schema; panics if `Projection::try_new` fails.
+fn build_recover_project_plan(schema: &DFSchema, input: LogicalPlan) -> LogicalPlan {
+    let col_exprs = schema
+        .fields()
+        .iter()
+        .map(|field| Expr::Column(field.qualified_column()))
+        .collect();
+    LogicalPlan::Projection(
+        Projection::try_new(col_exprs, Arc::new(input))
+            .expect("Cannot build projection plan from an invalid schema"),
+    )
+}
+
 /// Go through an expression tree and generate identifier.
 ///
 /// An identifier contains information of the expression itself and its 
sub-expression.
@@ -567,6 +578,7 @@ mod test {
 
     use arrow::datatypes::{Field, Schema};
 
+    use datafusion_common::DFSchema;
     use datafusion_expr::logical_plan::{table_scan, JoinType};
     use datafusion_expr::{
         avg, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, 
sum,
@@ -754,7 +766,6 @@ mod test {
         \n    TableScan: test";
 
         assert_optimized_plan_eq(expected, &plan);
-
         Ok(())
     }
 
@@ -762,16 +773,30 @@ mod test {
     fn redundant_project_fields() {
         let table_scan = test_table_scan().unwrap();
         let affected_id: BTreeSet<Identifier> =
-            ["c+a".to_string(), "d+a".to_string()].into_iter().collect();
-        let expr_set = [
+            ["c+a".to_string(), "b+a".to_string()].into_iter().collect();
+        let expr_set_1 = [
+            (
+                "c+a".to_string(),
+                (col("c") + col("a"), 1, DataType::UInt32),
+            ),
+            (
+                "b+a".to_string(),
+                (col("b") + col("a"), 1, DataType::UInt32),
+            ),
+        ]
+        .into_iter()
+        .collect();
+        let expr_set_2 = [
             ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
-            ("d+a".to_string(), (col("d+a"), 1, DataType::UInt32)),
+            ("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)),
         ]
         .into_iter()
         .collect();
         let project =
-            build_project_plan(table_scan, affected_id.clone(), 
&expr_set).unwrap();
-        let project_2 = build_project_plan(project, affected_id, 
&expr_set).unwrap();
+            build_common_expr_project_plan(table_scan, affected_id.clone(), 
&expr_set_1)
+                .unwrap();
+        let project_2 =
+            build_common_expr_project_plan(project, affected_id, 
&expr_set_2).unwrap();
 
         let mut field_set = BTreeSet::new();
         for field in project_2.schema().fields() {
@@ -789,15 +814,38 @@ mod test {
             .build()
             .unwrap();
         let affected_id: BTreeSet<Identifier> =
-            ["c+a".to_string(), "d+a".to_string()].into_iter().collect();
-        let expr_set = [
-            ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
-            ("d+a".to_string(), (col("d+a"), 1, DataType::UInt32)),
+            ["test1.c+test1.a".to_string(), "test1.b+test1.a".to_string()]
+                .into_iter()
+                .collect();
+        let expr_set_1 = [
+            (
+                "test1.c+test1.a".to_string(),
+                (col("test1.c") + col("test1.a"), 1, DataType::UInt32),
+            ),
+            (
+                "test1.b+test1.a".to_string(),
+                (col("test1.b") + col("test1.a"), 1, DataType::UInt32),
+            ),
+        ]
+        .into_iter()
+        .collect();
+        let expr_set_2 = [
+            (
+                "test1.c+test1.a".to_string(),
+                (col("test1.c+test1.a"), 1, DataType::UInt32),
+            ),
+            (
+                "test1.b+test1.a".to_string(),
+                (col("test1.b+test1.a"), 1, DataType::UInt32),
+            ),
         ]
         .into_iter()
         .collect();
-        let project = build_project_plan(join, affected_id.clone(), 
&expr_set).unwrap();
-        let project_2 = build_project_plan(project, affected_id, 
&expr_set).unwrap();
+        let project =
+            build_common_expr_project_plan(join, affected_id.clone(), 
&expr_set_1)
+                .unwrap();
+        let project_2 =
+            build_common_expr_project_plan(project, affected_id, 
&expr_set_2).unwrap();
 
         let mut field_set = BTreeSet::new();
         for field in project_2.schema().fields() {
@@ -839,10 +887,6 @@ mod test {
             .collect();
         let formatted_fields_with_datatype = 
format!("{fields_with_datatypes:#?}");
         let expected = r###"[
-    (
-        "CAST(table.a AS Int64)table.a",
-        Int64,
-    ),
     (
         "a",
         UInt64,
@@ -858,4 +902,22 @@ mod test {
 ]"###;
         assert_eq!(expected, formatted_fields_with_datatype);
     }
+
+    #[test]
+    fn filter_schema_changed() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        // The predicate repeats `lit(1).gt(col("a"))`, giving CSE a common subexpression to extract.
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(lit(1).gt(col("a")).and(lit(1).gt(col("a"))))?
+            .build()?;
+        // Expect an inner Projection computing it once, and an outer one restoring the Filter's schema.
+        let expected = "Projection: test.a, test.b, test.c\
+        \n  Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND 
Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\
+        \n    Projection: Int32(1) > test.a AS Int32(1) > 
test.atest.aInt32(1), test.a, test.b, test.c\
+        \n      TableScan: test";
+
+        assert_optimized_plan_eq(expected, &plan);
+
+        Ok(())
+    }
 }

Reply via email to