alamb commented on a change in pull request #7880:
URL: https://github.com/apache/arrow/pull/7880#discussion_r470753855



##########
File path: rust/datafusion/src/optimizer/type_coercion.rs
##########
@@ -43,138 +45,77 @@ impl<'a> TypeCoercionRule<'a> {
         Self { scalar_functions }
     }
 
-    /// Rewrite an expression list to include explicit CAST operations when 
required
-    fn rewrite_expr_list(&self, expr: &[Expr], schema: &Schema) -> 
Result<Vec<Expr>> {
-        Ok(expr
+    /// Rewrite an expression to include explicit CAST operations when required
+    fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> {
+        let expressions = utils::expr_expressions(expr)?;
+
+        // recurse of the re-write
+        let mut expressions = expressions
             .iter()
             .map(|e| self.rewrite_expr(e, schema))
-            .collect::<Result<Vec<_>>>()?)
-    }
+            .collect::<Result<Vec<_>>>()?;
 
-    /// Rewrite an expression to include explicit CAST operations when required
-    fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> {
+        // modify `expressions` by introducing casts when necessary
         match expr {
-            Expr::BinaryExpr { left, op, right } => {
-                let left = self.rewrite_expr(left, schema)?;
-                let right = self.rewrite_expr(right, schema)?;
-                let left_type = left.get_type(schema)?;
-                let right_type = right.get_type(schema)?;
-                if left_type == right_type {
-                    Ok(Expr::BinaryExpr {
-                        left: Box::new(left),
-                        op: op.clone(),
-                        right: Box::new(right),
-                    })
-                } else {
+            Expr::BinaryExpr { .. } => {
+                let left_type = expressions[0].get_type(schema)?;
+                let right_type = expressions[1].get_type(schema)?;
+                if left_type != right_type {
                     let super_type = utils::get_supertype(&left_type, 
&right_type)?;
-                    Ok(Expr::BinaryExpr {
-                        left: Box::new(left.cast_to(&super_type, schema)?),
-                        op: op.clone(),
-                        right: Box::new(right.cast_to(&super_type, schema)?),
-                    })
+
+                    expressions[0] = expressions[0].cast_to(&super_type, 
schema)?;
+                    expressions[1] = expressions[1].cast_to(&super_type, 
schema)?;
                 }
             }
-            Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(self.rewrite_expr(e, 
schema)?))),
-            Expr::IsNotNull(e) => {
-                Ok(Expr::IsNotNull(Box::new(self.rewrite_expr(e, schema)?)))
-            }
-            Expr::ScalarFunction {
-                name,
-                args,
-                return_type,
-            } => {
+            Expr::ScalarFunction { name, .. } => {
                 // cast the inputs of scalar functions to the appropriate type 
where possible
                 match self.scalar_functions.get(name) {
                     Some(func_meta) => {
-                        let mut func_args = Vec::with_capacity(args.len());
-                        for i in 0..args.len() {
+                        for i in 0..expressions.len() {
                             let field = &func_meta.args[i];
-                            let expr = self.rewrite_expr(&args[i], schema)?;
-                            let actual_type = expr.get_type(schema)?;
+                            let actual_type = expressions[i].get_type(schema)?;
                             let required_type = field.data_type();
-                            if &actual_type == required_type {
-                                func_args.push(expr)
-                            } else {
+                            if &actual_type != required_type {
                                 let super_type =
                                     utils::get_supertype(&actual_type, 
required_type)?;
-                                func_args.push(expr.cast_to(&super_type, 
schema)?);
-                            }
+                                expressions[i] =
+                                    expressions[i].cast_to(&super_type, 
schema)?
+                            };
                         }
-
-                        Ok(Expr::ScalarFunction {
-                            name: name.clone(),
-                            args: func_args,
-                            return_type: return_type.clone(),
-                        })
                     }
-                    _ => Err(ExecutionError::General(format!(
-                        "Invalid scalar function {}",
-                        name
-                    ))),
+                    _ => {
+                        return Err(ExecutionError::General(format!(
+                            "Invalid scalar function {}",
+                            name
+                        )))
+                    }
                 }
             }
-            Expr::AggregateFunction {
-                name,
-                args,
-                return_type,
-            } => Ok(Expr::AggregateFunction {
-                name: name.clone(),
-                args: args
-                    .iter()
-                    .map(|a| self.rewrite_expr(a, schema))
-                    .collect::<Result<Vec<_>>>()?,
-                return_type: return_type.clone(),
-            }),
-            Expr::Cast { .. } => Ok(expr.clone()),
-            Expr::Column(_) => Ok(expr.clone()),
-            Expr::Alias(expr, alias) => Ok(Expr::Alias(
-                Box::new(self.rewrite_expr(expr, schema)?),
-                alias.to_owned(),
-            )),
-            Expr::Literal(_) => Ok(expr.clone()),
-            Expr::Not(_) => Ok(expr.clone()),
-            Expr::Sort { .. } => Ok(expr.clone()),
-            Expr::Wildcard { .. } => Err(ExecutionError::General(
-                "Wildcard expressions are not valid in a logical query 
plan".to_owned(),
-            )),
-            Expr::Nested(e) => self.rewrite_expr(e, schema),
-        }
+            _ => {}
+        };
+        utils::from_expression(expr, &expressions)
     }
 }
 
 impl<'a> OptimizerRule for TypeCoercionRule<'a> {
     fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
-        match plan {
-            LogicalPlan::Projection { expr, input, .. } => {
-                LogicalPlanBuilder::from(&self.optimize(input)?)
-                    .project(self.rewrite_expr_list(expr, input.schema())?)?
-                    .build()
-            }
-            LogicalPlan::Selection { expr, input, .. } => {
-                LogicalPlanBuilder::from(&self.optimize(input)?)
-                    .filter(self.rewrite_expr(expr, input.schema())?)?
-                    .build()
-            }
-            LogicalPlan::Aggregate {
-                input,
-                group_expr,
-                aggr_expr,
-                ..
-            } => LogicalPlanBuilder::from(&self.optimize(input)?)
-                .aggregate(
-                    self.rewrite_expr_list(group_expr, input.schema())?,
-                    self.rewrite_expr_list(aggr_expr, input.schema())?,
-                )?
-                .build(),
-            LogicalPlan::TableScan { .. } => Ok(plan.clone()),
-            LogicalPlan::InMemoryScan { .. } => Ok(plan.clone()),
-            LogicalPlan::ParquetScan { .. } => Ok(plan.clone()),
-            LogicalPlan::CsvScan { .. } => Ok(plan.clone()),
-            LogicalPlan::EmptyRelation { .. } => Ok(plan.clone()),
-            LogicalPlan::Limit { .. } => Ok(plan.clone()),
-            LogicalPlan::Sort { .. } => Ok(plan.clone()),
-            LogicalPlan::CreateExternalTable { .. } => Ok(plan.clone()),
-        }
+        let inputs = utils::inputs(plan);
+        let expressions = utils::expressions(plan);
+
+        // apply the optimization to all inputs of the plan
+        let new_inputs = inputs
+            .iter()
+            .map(|plan| self.optimize(*plan))
+            .collect::<Result<Vec<_>>>()?;
+        // re-write all expressions on this plan.
+        // This assumes a single input, [0]. It wont work for join, subqueries 
and union operations with more than one input.
+        // It is currently not an issue as we do not have any plan with more 
than one input.
+        let new_expressions = expressions

Review comment:
       ```suggestion
           assert!(expressions.len() == 0 || inputs.len() > 0, "Assume that all 
plan nodes with expressions had inputs");
           let new_expressions = expressions
   ```
   
   I think `EmptyRelation`,  
https://github.com/apache/arrow/blob/master/rust/datafusion/src/logicalplan.rs#L761-L764,
 for example, has no input LogicalPlan. But perhaps you are saying: "even though 
`EmptyRelation` has no inputs (and thus could cause `inputs[0].schema()` to 
panic), it also has no expressions, so the potentially panicking code won't be 
run."
   
   I guess I was thinking of the future, where we might add expressions to root 
nodes (e.g. perhaps filtering *during* a table scan or something), which would 
then have expressions but no input.
   
    I think this code is fine as is. Perhaps we could make the code slightly 
easier to work with in the future if we added something like the assert 
suggested here, which checks that any plan node with expressions also has 
inputs, so that a violation fails with a clear message rather than an opaque 
out-of-bounds panic on `inputs[0]`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to