jorgecarleitao commented on a change in pull request #7880:
URL: https://github.com/apache/arrow/pull/7880#discussion_r470414491



##########
File path: rust/datafusion/src/optimizer/type_coercion.rs
##########
@@ -43,138 +45,77 @@ impl<'a> TypeCoercionRule<'a> {
         Self { scalar_functions }
     }
 
-    /// Rewrite an expression list to include explicit CAST operations when 
required
-    fn rewrite_expr_list(&self, expr: &[Expr], schema: &Schema) -> 
Result<Vec<Expr>> {
-        Ok(expr
+    /// Rewrite an expression to include explicit CAST operations when required
+    fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> {
+        let expressions = utils::expr_expressions(expr)?;
+
+        // recurse of the re-write
+        let mut expressions = expressions
             .iter()
             .map(|e| self.rewrite_expr(e, schema))
-            .collect::<Result<Vec<_>>>()?)
-    }
+            .collect::<Result<Vec<_>>>()?;
 
-    /// Rewrite an expression to include explicit CAST operations when required
-    fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> {
+        // modify `expressions` by introducing casts when necessary
         match expr {
-            Expr::BinaryExpr { left, op, right } => {
-                let left = self.rewrite_expr(left, schema)?;
-                let right = self.rewrite_expr(right, schema)?;
-                let left_type = left.get_type(schema)?;
-                let right_type = right.get_type(schema)?;
-                if left_type == right_type {
-                    Ok(Expr::BinaryExpr {
-                        left: Box::new(left),
-                        op: op.clone(),
-                        right: Box::new(right),
-                    })
-                } else {
+            Expr::BinaryExpr { .. } => {
+                let left_type = expressions[0].get_type(schema)?;
+                let right_type = expressions[1].get_type(schema)?;
+                if left_type != right_type {
                     let super_type = utils::get_supertype(&left_type, 
&right_type)?;
-                    Ok(Expr::BinaryExpr {
-                        left: Box::new(left.cast_to(&super_type, schema)?),
-                        op: op.clone(),
-                        right: Box::new(right.cast_to(&super_type, schema)?),
-                    })
+
+                    expressions[0] = expressions[0].cast_to(&super_type, 
schema)?;
+                    expressions[1] = expressions[1].cast_to(&super_type, 
schema)?;
                 }
             }
-            Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(self.rewrite_expr(e, 
schema)?))),
-            Expr::IsNotNull(e) => {
-                Ok(Expr::IsNotNull(Box::new(self.rewrite_expr(e, schema)?)))
-            }
-            Expr::ScalarFunction {
-                name,
-                args,
-                return_type,
-            } => {
+            Expr::ScalarFunction { name, .. } => {
                 // cast the inputs of scalar functions to the appropriate type 
where possible
                 match self.scalar_functions.get(name) {
                     Some(func_meta) => {
-                        let mut func_args = Vec::with_capacity(args.len());
-                        for i in 0..args.len() {
+                        for i in 0..expressions.len() {
                             let field = &func_meta.args[i];
-                            let expr = self.rewrite_expr(&args[i], schema)?;
-                            let actual_type = expr.get_type(schema)?;
+                            let actual_type = expressions[i].get_type(schema)?;
                             let required_type = field.data_type();
-                            if &actual_type == required_type {
-                                func_args.push(expr)
-                            } else {
+                            if &actual_type != required_type {
                                 let super_type =
                                     utils::get_supertype(&actual_type, 
required_type)?;
-                                func_args.push(expr.cast_to(&super_type, 
schema)?);
-                            }
+                                expressions[i] =
+                                    expressions[i].cast_to(&super_type, 
schema)?
+                            };
                         }
-
-                        Ok(Expr::ScalarFunction {
-                            name: name.clone(),
-                            args: func_args,
-                            return_type: return_type.clone(),
-                        })
                     }
-                    _ => Err(ExecutionError::General(format!(
-                        "Invalid scalar function {}",
-                        name
-                    ))),
+                    _ => {
+                        return Err(ExecutionError::General(format!(
+                            "Invalid scalar function {}",
+                            name
+                        )))
+                    }
                 }
             }
-            Expr::AggregateFunction {
-                name,
-                args,
-                return_type,
-            } => Ok(Expr::AggregateFunction {
-                name: name.clone(),
-                args: args
-                    .iter()
-                    .map(|a| self.rewrite_expr(a, schema))
-                    .collect::<Result<Vec<_>>>()?,
-                return_type: return_type.clone(),
-            }),
-            Expr::Cast { .. } => Ok(expr.clone()),
-            Expr::Column(_) => Ok(expr.clone()),
-            Expr::Alias(expr, alias) => Ok(Expr::Alias(
-                Box::new(self.rewrite_expr(expr, schema)?),
-                alias.to_owned(),
-            )),
-            Expr::Literal(_) => Ok(expr.clone()),
-            Expr::Not(_) => Ok(expr.clone()),
-            Expr::Sort { .. } => Ok(expr.clone()),
-            Expr::Wildcard { .. } => Err(ExecutionError::General(
-                "Wildcard expressions are not valid in a logical query 
plan".to_owned(),
-            )),
-            Expr::Nested(e) => self.rewrite_expr(e, schema),
-        }
+            _ => {}
+        };
+        utils::from_expression(expr, &expressions)
     }
 }
 
 impl<'a> OptimizerRule for TypeCoercionRule<'a> {
     fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
-        match plan {
-            LogicalPlan::Projection { expr, input, .. } => {
-                LogicalPlanBuilder::from(&self.optimize(input)?)
-                    .project(self.rewrite_expr_list(expr, input.schema())?)?
-                    .build()
-            }
-            LogicalPlan::Selection { expr, input, .. } => {
-                LogicalPlanBuilder::from(&self.optimize(input)?)
-                    .filter(self.rewrite_expr(expr, input.schema())?)?
-                    .build()
-            }
-            LogicalPlan::Aggregate {
-                input,
-                group_expr,
-                aggr_expr,
-                ..
-            } => LogicalPlanBuilder::from(&self.optimize(input)?)
-                .aggregate(
-                    self.rewrite_expr_list(group_expr, input.schema())?,
-                    self.rewrite_expr_list(aggr_expr, input.schema())?,
-                )?
-                .build(),
-            LogicalPlan::TableScan { .. } => Ok(plan.clone()),
-            LogicalPlan::InMemoryScan { .. } => Ok(plan.clone()),
-            LogicalPlan::ParquetScan { .. } => Ok(plan.clone()),
-            LogicalPlan::CsvScan { .. } => Ok(plan.clone()),
-            LogicalPlan::EmptyRelation { .. } => Ok(plan.clone()),
-            LogicalPlan::Limit { .. } => Ok(plan.clone()),
-            LogicalPlan::Sort { .. } => Ok(plan.clone()),
-            LogicalPlan::CreateExternalTable { .. } => Ok(plan.clone()),
-        }
+        let inputs = utils::inputs(plan);
+        let expressions = utils::expressions(plan);
+
+        // apply the optimization to all inputs of the plan
+        let new_inputs = inputs
+            .iter()
+            .map(|plan| self.optimize(*plan))
+            .collect::<Result<Vec<_>>>()?;
+        // re-write all expressions on this plan.
+        // This assumes a single input, [0]. It wont work for join, subqueries 
and union operations with more than one input.
+        // It is currently not an issue as we do not have any plan with more 
than one input.
+        let new_expressions = expressions

Review comment:
       Good catch. No, because I am unsure how that could be possible: if we 
have expressions on a plan, we need an `input` to convert them to physical 
expressions and to evaluate them against. AFAIK an expression always requires 
an input to be evaluated against. 
   
   Do you have an example in mind?
   
   AFAIK even a literal expression requires a schema to pass to 
`Expr::get_type` and `Expr::name`.
   
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to