This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 0cf563062 Fix output schema generated by CommonSubExprEliminate (#3726)
0cf563062 is described below

commit 0cf5630626ba7e3d814ef00477bcae2bc3cae9c2
Author: Alexander Spies <[email protected]>
AuthorDate: Tue Oct 11 14:14:45 2022 +0200

    Fix output schema generated by CommonSubExprEliminate (#3726)
    
    * CommonSubexprEliminate: Fix schema of the additional columns
    
    * Use correct types in test id_array_visitor
    
    * Re-enable fallback schema for datatype resolution
    
    Fall back to the merged schema from the whole logical plan if the input
    schema was not sufficient to resolve the datatype of a sub-expression.
    This re-enables the fallback logic added in 3860cd3 (#1925).
    
    * Add comment on fallback logic using all schemas
    
    Point out that it can likely be removed.
---
 .../optimizer/src/common_subexpr_eliminate.rs      | 183 +++++++++++++++++----
 1 file changed, 149 insertions(+), 34 deletions(-)

diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 552c03d3d..cea5e8c46 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -19,7 +19,7 @@
 
 use crate::{OptimizerConfig, OptimizerRule};
 use arrow::datatypes::DataType;
-use datafusion_common::{DFField, DFSchema, DataFusionError, Result};
+use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, 
Result};
 use datafusion_expr::{
     col,
     expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion},
@@ -94,7 +94,10 @@ fn optimize(
             schema,
             alias,
         }) => {
-            let arrays = to_arrays(expr, input, &mut expr_set)?;
+            let input_schema = Arc::clone(input.schema());
+            let all_schemas: Vec<DFSchemaRef> =
+                plan.all_schemas().into_iter().cloned().collect();
+            let arrays = to_arrays(expr, input_schema, all_schemas, &mut 
expr_set)?;
 
             let (mut new_expr, new_input) = rewrite_expr(
                 &[expr],
@@ -112,22 +115,18 @@ fn optimize(
             )?))
         }
         LogicalPlan::Filter(Filter { predicate, input }) => {
-            let schema = plan.schema().as_ref().clone();
-            let data_type = if let Ok(data_type) = predicate.get_type(&schema) 
{
-                data_type
-            } else {
-                // predicate type could not be resolved in schema, fall back 
to all schemas
-                let schemas = plan.all_schemas();
-                let all_schema =
-                    schemas.into_iter().fold(DFSchema::empty(), |mut lhs, rhs| 
{
-                        lhs.merge(rhs);
-                        lhs
-                    });
-                predicate.get_type(&all_schema)?
-            };
+            let input_schema = Arc::clone(input.schema());
+            let all_schemas: Vec<DFSchemaRef> =
+                plan.all_schemas().into_iter().cloned().collect();
 
             let mut id_array = vec![];
-            expr_to_identifier(predicate, &mut expr_set, &mut id_array, 
data_type)?;
+            expr_to_identifier(
+                predicate,
+                &mut expr_set,
+                &mut id_array,
+                input_schema,
+                all_schemas,
+            )?;
 
             let (mut new_expr, new_input) = rewrite_expr(
                 &[&[predicate.clone()]],
@@ -153,7 +152,11 @@ fn optimize(
             window_expr,
             schema,
         }) => {
-            let arrays = to_arrays(window_expr, input, &mut expr_set)?;
+            let input_schema = Arc::clone(input.schema());
+            let all_schemas: Vec<DFSchemaRef> =
+                plan.all_schemas().into_iter().cloned().collect();
+            let arrays =
+                to_arrays(window_expr, input_schema, all_schemas, &mut 
expr_set)?;
 
             let (mut new_expr, new_input) = rewrite_expr(
                 &[window_expr],
@@ -175,8 +178,17 @@ fn optimize(
             input,
             schema,
         }) => {
-            let group_arrays = to_arrays(group_expr, input, &mut expr_set)?;
-            let aggr_arrays = to_arrays(aggr_expr, input, &mut expr_set)?;
+            let input_schema = Arc::clone(input.schema());
+            let all_schemas: Vec<DFSchemaRef> =
+                plan.all_schemas().into_iter().cloned().collect();
+            let group_arrays = to_arrays(
+                group_expr,
+                Arc::clone(&input_schema),
+                all_schemas.clone(),
+                &mut expr_set,
+            )?;
+            let aggr_arrays =
+                to_arrays(aggr_expr, input_schema, all_schemas, &mut 
expr_set)?;
 
             let (mut new_expr, new_input) = rewrite_expr(
                 &[group_expr, aggr_expr],
@@ -197,7 +209,10 @@ fn optimize(
             )?))
         }
         LogicalPlan::Sort(Sort { expr, input, fetch }) => {
-            let arrays = to_arrays(expr, input, &mut expr_set)?;
+            let input_schema = Arc::clone(input.schema());
+            let all_schemas: Vec<DFSchemaRef> =
+                plan.all_schemas().into_iter().cloned().collect();
+            let arrays = to_arrays(expr, input_schema, all_schemas, &mut 
expr_set)?;
 
             let (mut new_expr, new_input) = rewrite_expr(
                 &[expr],
@@ -255,14 +270,20 @@ fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) -> Result<Vec<Expr>> {
 
 fn to_arrays(
     expr: &[Expr],
-    input: &LogicalPlan,
+    input_schema: DFSchemaRef,
+    all_schemas: Vec<DFSchemaRef>,
     expr_set: &mut ExprSet,
 ) -> Result<Vec<Vec<(usize, String)>>> {
     expr.iter()
         .map(|e| {
-            let data_type = e.get_type(input.schema())?;
             let mut id_array = vec![];
-            expr_to_identifier(e, expr_set, &mut id_array, data_type)?;
+            expr_to_identifier(
+                e,
+                expr_set,
+                &mut id_array,
+                Arc::clone(&input_schema),
+                all_schemas.clone(),
+            )?;
 
             Ok(id_array)
         })
@@ -370,7 +391,15 @@ struct ExprIdentifierVisitor<'a> {
     expr_set: &'a mut ExprSet,
     /// series number (usize) and identifier.
     id_array: &'a mut Vec<(usize, Identifier)>,
-    data_type: DataType,
+    /// input schema for the node that we're optimizing, so we can determine the correct datatype
+    /// for each subexpression
+    input_schema: DFSchemaRef,
+    /// all schemas in the logical plan, as a fall back if we cannot resolve an expression type
+    /// from the input schema alone
+    // This fallback should never be necessary as the expression datatype should always be
+    // resolvable from the input schema of the node that's being optimized.
+    // todo: This can likely be removed if we are sure it's safe to do so.
+    all_schemas: Vec<DFSchemaRef>,
 
     // inner states
     visit_stack: Vec<VisitRecord>,
@@ -448,7 +477,25 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> {
 
         self.id_array[idx] = (self.series_number, desc.clone());
         self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
-        let data_type = self.data_type.clone();
+
+        let data_type = if let Ok(data_type) = 
expr.get_type(&self.input_schema) {
+            data_type
+        } else {
+            // Expression type could not be resolved in schema, fall back to all schemas.
+            //
+            // This fallback should never be necessary as the expression datatype should always be
+            // resolvable from the input schema of the node that's being optimized.
+            // todo: This else-branch can likely be removed if we are sure it's safe to do so.
+            let merged_schema =
+                self.all_schemas
+                    .iter()
+                    .fold(DFSchema::empty(), |mut lhs, rhs| {
+                        lhs.merge(rhs);
+                        lhs
+                    });
+            expr.get_type(&merged_schema)?
+        };
+
         self.expr_set
             .entry(desc)
             .or_insert_with(|| (expr.clone(), 0, data_type))
@@ -462,12 +509,14 @@ fn expr_to_identifier(
     expr: &Expr,
     expr_set: &mut ExprSet,
     id_array: &mut Vec<(usize, Identifier)>,
-    data_type: DataType,
+    input_schema: DFSchemaRef,
+    all_schemas: Vec<DFSchemaRef>,
 ) -> Result<()> {
     expr.accept(ExprIdentifierVisitor {
         expr_set,
         id_array,
-        data_type,
+        input_schema,
+        all_schemas,
         visit_stack: vec![],
         node_count: 0,
         series_number: 0,
@@ -577,7 +626,8 @@ fn replace_common_expr(
 mod test {
     use super::*;
     use crate::test::*;
-    use datafusion_expr::logical_plan::JoinType;
+    use arrow::datatypes::{Field, Schema};
+    use datafusion_expr::logical_plan::{table_scan, JoinType};
     use datafusion_expr::{
         avg, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, 
sum,
         Operator,
@@ -597,7 +647,7 @@ mod test {
     fn id_array_visitor() -> Result<()> {
         let expr = binary_expr(
             binary_expr(
-                sum(binary_expr(col("a"), Operator::Plus, lit("1"))),
+                sum(binary_expr(col("a"), Operator::Plus, lit(1))),
                 Operator::Minus,
                 avg(col("c")),
             ),
@@ -605,14 +655,28 @@ mod test {
             lit(2),
         );
 
+        let schema = Arc::new(DFSchema::new_with_metadata(
+            vec![
+                DFField::new(None, "a", DataType::Int64, false),
+                DFField::new(None, "c", DataType::Int64, false),
+            ],
+            Default::default(),
+        )?);
+
         let mut id_array = vec![];
-        expr_to_identifier(&expr, &mut HashMap::new(), &mut id_array, 
DataType::Int64)?;
+        expr_to_identifier(
+            &expr,
+            &mut HashMap::new(),
+            &mut id_array,
+            Arc::clone(&schema),
+            vec![schema],
+        )?;
 
         let expected = vec![
-            (9, "SUM(a + Utf8(\"1\")) - AVG(c) * Int32(2)Int32(2)SUM(a + 
Utf8(\"1\")) - AVG(c)AVG(c)cSUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"),
-            (7, "SUM(a + Utf8(\"1\")) - AVG(c)AVG(c)cSUM(a + Utf8(\"1\"))a + 
Utf8(\"1\")Utf8(\"1\")a"),
-            (4, "SUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"),
-            (3, "a + Utf8(\"1\")Utf8(\"1\")a"),
+            (9, "SUM(a + Int32(1)) - AVG(c) * Int32(2)Int32(2)SUM(a + 
Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
+            (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + 
Int32(1)Int32(1)a"),
+            (4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"),
+            (3, "a + Int32(1)Int32(1)a"),
             (1, ""),
             (2, ""),
             (6, "AVG(c)c"),
@@ -796,4 +860,55 @@ mod test {
             assert!(field_set.insert(field.qualified_name()));
         }
     }
+
+    #[test]
+    fn eliminated_subexpr_datatype() {
+        use datafusion_expr::cast;
+
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::UInt64, false),
+            Field::new("b", DataType::UInt64, false),
+            Field::new("c", DataType::UInt64, false),
+        ]);
+
+        let plan = table_scan(Some("table"), &schema, None)
+            .unwrap()
+            .filter(
+                cast(col("a"), DataType::Int64)
+                    .lt(lit(1_i64))
+                    .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
+            )
+            .unwrap()
+            .build()
+            .unwrap();
+        let rule = CommonSubexprEliminate {};
+        let optimized_plan = rule.optimize(&plan, &mut 
OptimizerConfig::new()).unwrap();
+
+        let schema = optimized_plan.schema();
+        let fields_with_datatypes: Vec<_> = schema
+            .fields()
+            .iter()
+            .map(|field| (field.name(), field.data_type()))
+            .collect();
+        let formatted_fields_with_datatype = 
format!("{fields_with_datatypes:#?}");
+        let expected = r###"[
+    (
+        "CAST(table.a AS Int64)table.a",
+        Int64,
+    ),
+    (
+        "a",
+        UInt64,
+    ),
+    (
+        "b",
+        UInt64,
+    ),
+    (
+        "c",
+        UInt64,
+    ),
+]"###;
+        assert_eq!(expected, formatted_fields_with_datatype);
+    }
 }

Reply via email to