alamb commented on a change in pull request #7880: URL: https://github.com/apache/arrow/pull/7880#discussion_r470753855
########## File path: rust/datafusion/src/optimizer/type_coercion.rs ########## @@ -43,138 +45,77 @@ impl<'a> TypeCoercionRule<'a> { Self { scalar_functions } } - /// Rewrite an expression list to include explicit CAST operations when required - fn rewrite_expr_list(&self, expr: &[Expr], schema: &Schema) -> Result<Vec<Expr>> { - Ok(expr + /// Rewrite an expression to include explicit CAST operations when required + fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> { + let expressions = utils::expr_expressions(expr)?; + + // recurse of the re-write + let mut expressions = expressions .iter() .map(|e| self.rewrite_expr(e, schema)) - .collect::<Result<Vec<_>>>()?) - } + .collect::<Result<Vec<_>>>()?; - /// Rewrite an expression to include explicit CAST operations when required - fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> { + // modify `expressions` by introducing casts when necessary match expr { - Expr::BinaryExpr { left, op, right } => { - let left = self.rewrite_expr(left, schema)?; - let right = self.rewrite_expr(right, schema)?; - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; - if left_type == right_type { - Ok(Expr::BinaryExpr { - left: Box::new(left), - op: op.clone(), - right: Box::new(right), - }) - } else { + Expr::BinaryExpr { .. 
} => { + let left_type = expressions[0].get_type(schema)?; + let right_type = expressions[1].get_type(schema)?; + if left_type != right_type { let super_type = utils::get_supertype(&left_type, &right_type)?; - Ok(Expr::BinaryExpr { - left: Box::new(left.cast_to(&super_type, schema)?), - op: op.clone(), - right: Box::new(right.cast_to(&super_type, schema)?), - }) + + expressions[0] = expressions[0].cast_to(&super_type, schema)?; + expressions[1] = expressions[1].cast_to(&super_type, schema)?; } } - Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(self.rewrite_expr(e, schema)?))), - Expr::IsNotNull(e) => { - Ok(Expr::IsNotNull(Box::new(self.rewrite_expr(e, schema)?))) - } - Expr::ScalarFunction { - name, - args, - return_type, - } => { + Expr::ScalarFunction { name, .. } => { // cast the inputs of scalar functions to the appropriate type where possible match self.scalar_functions.get(name) { Some(func_meta) => { - let mut func_args = Vec::with_capacity(args.len()); - for i in 0..args.len() { + for i in 0..expressions.len() { let field = &func_meta.args[i]; - let expr = self.rewrite_expr(&args[i], schema)?; - let actual_type = expr.get_type(schema)?; + let actual_type = expressions[i].get_type(schema)?; let required_type = field.data_type(); - if &actual_type == required_type { - func_args.push(expr) - } else { + if &actual_type != required_type { let super_type = utils::get_supertype(&actual_type, required_type)?; - func_args.push(expr.cast_to(&super_type, schema)?); - } + expressions[i] = + expressions[i].cast_to(&super_type, schema)? 
+ }; } - - Ok(Expr::ScalarFunction { - name: name.clone(), - args: func_args, - return_type: return_type.clone(), - }) } - _ => Err(ExecutionError::General(format!( - "Invalid scalar function {}", - name - ))), + _ => { + return Err(ExecutionError::General(format!( + "Invalid scalar function {}", + name + ))) + } } } - Expr::AggregateFunction { - name, - args, - return_type, - } => Ok(Expr::AggregateFunction { - name: name.clone(), - args: args - .iter() - .map(|a| self.rewrite_expr(a, schema)) - .collect::<Result<Vec<_>>>()?, - return_type: return_type.clone(), - }), - Expr::Cast { .. } => Ok(expr.clone()), - Expr::Column(_) => Ok(expr.clone()), - Expr::Alias(expr, alias) => Ok(Expr::Alias( - Box::new(self.rewrite_expr(expr, schema)?), - alias.to_owned(), - )), - Expr::Literal(_) => Ok(expr.clone()), - Expr::Not(_) => Ok(expr.clone()), - Expr::Sort { .. } => Ok(expr.clone()), - Expr::Wildcard { .. } => Err(ExecutionError::General( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), - Expr::Nested(e) => self.rewrite_expr(e, schema), - } + _ => {} + }; + utils::from_expression(expr, &expressions) } } impl<'a> OptimizerRule for TypeCoercionRule<'a> { fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> { - match plan { - LogicalPlan::Projection { expr, input, .. } => { - LogicalPlanBuilder::from(&self.optimize(input)?) - .project(self.rewrite_expr_list(expr, input.schema())?)? - .build() - } - LogicalPlan::Selection { expr, input, .. } => { - LogicalPlanBuilder::from(&self.optimize(input)?) - .filter(self.rewrite_expr(expr, input.schema())?)? - .build() - } - LogicalPlan::Aggregate { - input, - group_expr, - aggr_expr, - .. - } => LogicalPlanBuilder::from(&self.optimize(input)?) - .aggregate( - self.rewrite_expr_list(group_expr, input.schema())?, - self.rewrite_expr_list(aggr_expr, input.schema())?, - )? - .build(), - LogicalPlan::TableScan { .. } => Ok(plan.clone()), - LogicalPlan::InMemoryScan { .. 
} => Ok(plan.clone()), - LogicalPlan::ParquetScan { .. } => Ok(plan.clone()), - LogicalPlan::CsvScan { .. } => Ok(plan.clone()), - LogicalPlan::EmptyRelation { .. } => Ok(plan.clone()), - LogicalPlan::Limit { .. } => Ok(plan.clone()), - LogicalPlan::Sort { .. } => Ok(plan.clone()), - LogicalPlan::CreateExternalTable { .. } => Ok(plan.clone()), - } + let inputs = utils::inputs(plan); + let expressions = utils::expressions(plan); + + // apply the optimization to all inputs of the plan + let new_inputs = inputs + .iter() + .map(|plan| self.optimize(*plan)) + .collect::<Result<Vec<_>>>()?; + // re-write all expressions on this plan. + // This assumes a single input, [0]. It wont work for join, subqueries and union operations with more than one input. + // It is currently not an issue as we do not have any plan with more than one input. + let new_expressions = expressions Review comment: ```suggestion assert!(expressions.len() == 0 || inputs.len() > 0, "Assume that all plan nodes with expressions had inputs"); let new_expressions = expressions ``` I think `EmptyRelation` (https://github.com/apache/arrow/blob/master/rust/datafusion/src/logicalplan.rs#L761-L764), for example, has no input LogicalPlan. But perhaps you are saying: "even though `EmptyRelation` has no inputs (and thus could cause `inputs[0].schema()` to panic), it also has no expressions, so the potentially panicking code won't be run." I guess I was thinking ahead to a future where we add expressions to root nodes (e.g. perhaps filtering *during* a table scan or something), which would then have expressions but no input. I think this code is fine as is. Perhaps we could make the code slightly easier to work with in the future by adding something like the suggested assert above. It checks that any plan node with expressions also has at least one input, so a violation fails with an explicit message rather than an opaque index-out-of-bounds panic. ---------------------------------------------------------------- This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org