berkaysynnada commented on code in PR #8109:
URL: https://github.com/apache/arrow-datafusion/pull/8109#discussion_r1389035734
##########
datafusion/core/src/physical_optimizer/projection_pushdown.rs:
##########
@@ -791,119 +788,65 @@ fn update_expr(
projected_exprs: &[(Arc<dyn PhysicalExpr>, String)],
sync_with_child: bool,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
- let expr_any = expr.as_any();
- if let Some(column) = expr_any.downcast_ref::<Column>() {
- if sync_with_child {
- // Update the index of `column`:
- Ok(Some(projected_exprs[column.index()].0.clone()))
- } else {
- // Determine how to update `column` to accommodate `projected_exprs`:
- Ok(projected_exprs.iter().enumerate().find_map(
- |(index, (projected_expr, alias))| {
- projected_expr.as_any().downcast_ref::<Column>().and_then(
- |projected_column| {
- column
- .name()
- .eq(projected_column.name())
- .then(|| Arc::new(Column::new(alias, index)) as _)
- },
- )
- },
- ))
- }
- } else if let Some(binary) = expr_any.downcast_ref::<BinaryExpr>() {
- match (
- update_expr(binary.left(), projected_exprs, sync_with_child)?,
- update_expr(binary.right(), projected_exprs, sync_with_child)?,
- ) {
- (Some(left), Some(right)) => {
- Ok(Some(Arc::new(BinaryExpr::new(left, *binary.op(), right))))
- }
- _ => Ok(None),
- }
- } else if let Some(cast) = expr_any.downcast_ref::<CastExpr>() {
- update_expr(cast.expr(), projected_exprs, sync_with_child).map(|maybe_expr| {
- maybe_expr.map(|expr| {
- Arc::new(CastExpr::new(
- expr,
- cast.cast_type().clone(),
- Some(cast.cast_options().clone()),
- )) as _
- })
- })
- } else if expr_any.is::<Literal>() {
- Ok(Some(expr.clone()))
- } else if let Some(negative) = expr_any.downcast_ref::<NegativeExpr>() {
- update_expr(negative.arg(), projected_exprs, sync_with_child).map(|maybe_expr| {
- maybe_expr.map(|expr| Arc::new(NegativeExpr::new(expr)) as _)
- })
- } else if let Some(scalar_func) = expr_any.downcast_ref::<ScalarFunctionExpr>() {
- scalar_func
- .args()
- .iter()
- .map(|expr| update_expr(expr, projected_exprs, sync_with_child))
- .collect::<Result<Option<Vec<_>>>>()
- .map(|maybe_args| {
- maybe_args.map(|new_args| {
- Arc::new(ScalarFunctionExpr::new(
- scalar_func.name(),
- scalar_func.fun().clone(),
- new_args,
- scalar_func.return_type(),
- scalar_func.monotonicity().clone(),
- )) as _
- })
- })
- } else if let Some(case) = expr_any.downcast_ref::<CaseExpr>() {
- update_case_expr(case, projected_exprs, sync_with_child)
- } else {
- Ok(None)
+ #[derive(Debug, PartialEq)]
+ enum RewriteState {
+ /// The expression is unchanged.
+ Unchanged,
+ /// Some part of the expression has been rewritten
+ RewrittenValid,
+ /// Some part of the expression has been rewritten, but some column
+ /// references could not be.
+ RewrittenInvalid,
}
-}
-/// Updates the indices `case` refers to according to `projected_exprs`.
-fn update_case_expr(
- case: &CaseExpr,
- projected_exprs: &[(Arc<dyn PhysicalExpr>, String)],
- sync_with_child: bool,
-) -> Result<Option<Arc<dyn PhysicalExpr>>> {
- let new_case = case
- .expr()
- .map(|expr| update_expr(expr, projected_exprs, sync_with_child))
- .transpose()?
- .flatten();
-
- let new_else = case
- .else_expr()
- .map(|expr| update_expr(expr, projected_exprs, sync_with_child))
- .transpose()?
- .flatten();
-
- let new_when_then = case
- .when_then_expr()
- .iter()
- .map(|(when, then)| {
- Ok((
- update_expr(when, projected_exprs, sync_with_child)?,
- update_expr(then, projected_exprs, sync_with_child)?,
- ))
- })
- .collect::<Result<Vec<_>>>()?
- .into_iter()
- .filter_map(|(maybe_when, maybe_then)| match (maybe_when, maybe_then) {
- (Some(when), Some(then)) => Some((when, then)),
- _ => None,
- })
- .collect::<Vec<_>>();
+ let mut state = RewriteState::Unchanged;
- if new_when_then.len() != case.when_then_expr().len()
- || case.expr().is_some() && new_case.is_none()
- || case.else_expr().is_some() && new_else.is_none()
- {
- return Ok(None);
- }
+ let new_expr = expr
+ .clone()
+ .transform_up_mut(&mut |expr: Arc<dyn PhysicalExpr>| {
+ if state == RewriteState::RewrittenInvalid {
+ return Ok(Transformed::No(expr));
+ }
- CaseExpr::try_new(new_case, new_when_then, new_else).map(|e| Some(Arc::new(e) as _))
+ let Some(column) = expr.as_any().downcast_ref::<Column>() else {
+ return Ok(Transformed::No(expr));
+ };
+ if sync_with_child {
+ state = RewriteState::RewrittenValid;
+ // Update the index of `column`:
+ Ok(Transformed::Yes(projected_exprs[column.index()].0.clone()))
+ } else {
+ // Determine how to update `column` to accommodate `projected_exprs`:
+ let new_col = projected_exprs.iter().enumerate().find_map(
+ |(index, (projected_expr, alias))| {
+ projected_expr.as_any().downcast_ref::<Column>().and_then(
+ |projected_column| {
+ column
+ .name()
+ .eq(projected_column.name())
+ .then(|| Arc::new(Column::new(alias, index)) as _)
+ },
+ )
+ },
+ );
+ if let Some(new_col) = new_col {
+ state = RewriteState::RewrittenValid;
+ Ok(Transformed::Yes(new_col))
+ } else {
+ // didn't find a rewrite, stop trying
+ state = RewriteState::RewrittenInvalid;
+ Ok(Transformed::No(expr))
+ }
Review Comment:
```suggestion
state = RewriteState::RewrittenInvalid;
projected_exprs
.iter()
.enumerate()
.find_map(|(index, (projected_expr, alias))| {
projected_expr.as_any().downcast_ref::<Column>().and_then(
|projected_column| {
column.name().eq(projected_column.name()).then(|| {
state = RewriteState::RewrittenValid;
Arc::new(Column::new(alias, index)) as _
})
},
)
})
.map_or_else(|| Ok(Transformed::No(expr)), |c| Ok(Transformed::Yes(c)))
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]