This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 88c98e1c1e Refactor `UnwrapCastInComparison` to remove `Expr` clones (#10115)
88c98e1c1e is described below

commit 88c98e1c1ecec357548a89022053d3735568a853
Author: Peter Toth <[email protected]>
AuthorDate: Thu Apr 18 02:40:23 2024 +0200

    Refactor `UnwrapCastInComparison` to remove `Expr` clones (#10115)
---
 .../optimizer/src/unwrap_cast_in_comparison.rs     | 243 ++++++++++-----------
 1 file changed, 117 insertions(+), 126 deletions(-)

diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 5ede43a051..bd14584fd5 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -18,6 +18,7 @@
 //! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)`
 
 use std::cmp::Ordering;
+use std::mem;
 use std::sync::Arc;
 
 use crate::optimizer::ApplyOrder;
@@ -32,9 +33,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
 use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue};
 use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast};
 use datafusion_expr::utils::merge_schema;
-use datafusion_expr::{
-    binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
-};
+use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator};
 
 /// [`UnwrapCastInComparison`] attempts to remove casts from
 /// comparisons to literals ([`ScalarValue`]s) by applying the casts
@@ -140,140 +139,132 @@ struct UnwrapCastExprRewriter {
 impl TreeNodeRewriter for UnwrapCastExprRewriter {
     type Node = Expr;
 
-    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
-        match &expr {
+    fn f_up(&mut self, mut expr: Expr) -> Result<Transformed<Expr>> {
+        match &mut expr {
             // For case:
             // try_cast/cast(expr as data_type) op literal
             // literal op try_cast/cast(expr as data_type)
-            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-                let left = left.as_ref().clone();
-                let right = right.as_ref().clone();
-                let left_type = left.get_type(&self.schema)?;
-                let right_type = right.get_type(&self.schema)?;
-                // Because the plan has been done the type coercion, the left and right must be equal
-                if is_support_data_type(&left_type)
-                    && is_support_data_type(&right_type)
-                    && is_comparison_op(op)
-                {
-                    match (&left, &right) {
-                        (
-                            Expr::Literal(left_lit_value),
-                            Expr::TryCast(TryCast { expr, .. })
-                            | Expr::Cast(Cast { expr, .. }),
-                        ) => {
-                            // if the left_lit_value can be casted to the type of expr
-                            // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
-                            let expr_type = expr.get_type(&self.schema)?;
-                            let casted_scalar_value =
-                                try_cast_literal_to_type(left_lit_value, &expr_type)?;
-                            if let Some(value) = casted_scalar_value {
-                                // unwrap the cast/try_cast for the right expr
-                                return Ok(Transformed::yes(binary_expr(
-                                    lit(value),
-                                    *op,
-                                    expr.as_ref().clone(),
-                                )));
-                            }
-                        }
-                        (
-                            Expr::TryCast(TryCast { expr, .. })
-                            | Expr::Cast(Cast { expr, .. }),
-                            Expr::Literal(right_lit_value),
-                        ) => {
-                            // if the right_lit_value can be casted to the type of expr
-                            // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
-                            let expr_type = expr.get_type(&self.schema)?;
-                            let casted_scalar_value =
-                                try_cast_literal_to_type(right_lit_value, &expr_type)?;
-                            if let Some(value) = casted_scalar_value {
-                                // unwrap the cast/try_cast for the left expr
-                                return Ok(Transformed::yes(binary_expr(
-                                    expr.as_ref().clone(),
-                                    *op,
-                                    lit(value),
-                                )));
-                            }
-                        }
-                        (_, _) => {
-                            // do nothing
-                        }
+            Expr::BinaryExpr(BinaryExpr { left, op, right })
+                if {
+                    let Ok(left_type) = left.get_type(&self.schema) else {
+                        return Ok(Transformed::no(expr));
                     };
+                    let Ok(right_type) = right.get_type(&self.schema) else {
+                        return Ok(Transformed::no(expr));
+                    };
+                    is_support_data_type(&left_type)
+                        && is_support_data_type(&right_type)
+                        && is_comparison_op(op)
+                } =>
+            {
+                match (left.as_mut(), right.as_mut()) {
+                    (
+                        Expr::Literal(left_lit_value),
+                        Expr::TryCast(TryCast {
+                            expr: right_expr, ..
+                        })
+                        | Expr::Cast(Cast {
+                            expr: right_expr, ..
+                        }),
+                    ) => {
+                        // if the left_lit_value can be casted to the type of expr
+                        // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
+                        let Ok(expr_type) = right_expr.get_type(&self.schema) else {
+                            return Ok(Transformed::no(expr));
+                        };
+                        let Ok(Some(value)) =
+                            try_cast_literal_to_type(left_lit_value, &expr_type)
+                        else {
+                            return Ok(Transformed::no(expr));
+                        };
+                        **left = lit(value);
+                        // unwrap the cast/try_cast for the right expr
+                        **right =
+                            mem::replace(right_expr, Expr::Literal(ScalarValue::Null));
+                        Ok(Transformed::yes(expr))
+                    }
+                    (
+                        Expr::TryCast(TryCast {
+                            expr: left_expr, ..
+                        })
+                        | Expr::Cast(Cast {
+                            expr: left_expr, ..
+                        }),
+                        Expr::Literal(right_lit_value),
+                    ) => {
+                        // if the right_lit_value can be casted to the type of expr
+                        // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
+                        let Ok(expr_type) = left_expr.get_type(&self.schema) else {
+                            return Ok(Transformed::no(expr));
+                        };
+                        let Ok(Some(value)) =
+                            try_cast_literal_to_type(right_lit_value, &expr_type)
+                        else {
+                            return Ok(Transformed::no(expr));
+                        };
+                        // unwrap the cast/try_cast for the left expr
+                        **left =
+                            mem::replace(left_expr, Expr::Literal(ScalarValue::Null));
+                        **right = lit(value);
+                        Ok(Transformed::yes(expr))
+                    }
+                    _ => Ok(Transformed::no(expr)),
                 }
-                // return the new binary op
-                Ok(Transformed::yes(binary_expr(left, *op, right)))
             }
             // For case:
             // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
             Expr::InList(InList {
-                expr: left_expr,
-                list,
-                negated,
+                expr: left, list, ..
             }) => {
-                if let Some(
-                    Expr::TryCast(TryCast {
-                        expr: internal_left_expr,
-                        ..
-                    })
-                    | Expr::Cast(Cast {
-                        expr: internal_left_expr,
-                        ..
-                    }),
-                ) = Some(left_expr.as_ref())
-                {
-                    let internal_left = internal_left_expr.as_ref().clone();
-                    let internal_left_type = internal_left.get_type(&self.schema);
-                    if internal_left_type.is_err() {
-                        // error data type
-                        return Ok(Transformed::no(expr));
-                    }
-                    let internal_left_type = internal_left_type?;
-                    if !is_support_data_type(&internal_left_type) {
-                        // not supported data type
-                        return Ok(Transformed::no(expr));
-                    }
-                    let right_exprs = list
-                        .iter()
-                        .map(|right| {
-                            let right_type = right.get_type(&self.schema)?;
-                            if !is_support_data_type(&right_type) {
-                                return internal_err!(
-                                    "The type of list expr {} not support",
-                                    &right_type
-                                );
-                            }
-                            match right {
-                                Expr::Literal(right_lit_value) => {
-                                    // if the right_lit_value can be casted to the type of internal_left_expr
-                                    // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
-                                    let casted_scalar_value =
-                                        try_cast_literal_to_type(right_lit_value, &internal_left_type)?;
-                                    if let Some(value) = casted_scalar_value {
-                                        Ok(lit(value))
-                                    } else {
-                                        internal_err!(
-                                            "Can't cast the list expr {:?} to type {:?}",
-                                            right_lit_value, &internal_left_type
-                                        )
-                                    }
-                                }
-                                other_expr => internal_err!(
-                                    "Only support literal expr to optimize, but the expr is {:?}",
-                                    &other_expr
-                                ),
-                            }
-                        })
-                        .collect::<Result<Vec<_>>>();
-                    match right_exprs {
-                        Ok(right_exprs) => Ok(Transformed::yes(in_list(
-                            internal_left,
-                            right_exprs,
-                            *negated,
-                        ))),
-                        Err(_) => Ok(Transformed::no(expr)),
-                    }
-                } else {
-                    Ok(Transformed::no(expr))
+                let (Expr::TryCast(TryCast {
+                    expr: left_expr, ..
+                })
+                | Expr::Cast(Cast {
+                    expr: left_expr, ..
+                })) = left.as_mut()
+                else {
+                    return Ok(Transformed::no(expr));
+                };
+                let Ok(expr_type) = left_expr.get_type(&self.schema) else {
+                    return Ok(Transformed::no(expr));
+                };
+                if !is_support_data_type(&expr_type) {
+                    return Ok(Transformed::no(expr));
                 }
+                let Ok(right_exprs) = list
+                    .iter()
+                    .map(|right| {
+                        let right_type = right.get_type(&self.schema)?;
+                        if !is_support_data_type(&right_type) {
+                            internal_err!(
+                                "The type of list expr {} is not supported",
+                                &right_type
+                            )?;
+                        }
+                        match right {
+                            Expr::Literal(right_lit_value) => {
+                                // if the right_lit_value can be casted to the type of internal_left_expr
+                                // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
+                                let Ok(Some(value)) = try_cast_literal_to_type(right_lit_value, &expr_type) else {
+                                    internal_err!(
+                                        "Can't cast the list expr {:?} to type {:?}",
+                                        right_lit_value, &expr_type
+                                    )?
+                                };
+                                Ok(lit(value))
+                            }
+                            other_expr => internal_err!(
+                                "Only support literal expr to optimize, but the expr is {:?}",
+                                &other_expr
+                            ),
+                        }
+                    })
+                    .collect::<Result<Vec<_>>>() else {
+                    return Ok(Transformed::no(expr))
+                };
+                **left = mem::replace(left_expr, Expr::Literal(ScalarValue::Null));
+                *list = right_exprs;
+                Ok(Transformed::yes(expr))
             }
             // TODO: handle other expr type and dfs visit them
             _ => Ok(Transformed::no(expr)),

Reply via email to