This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new eb4ee6272c Move `UnwrapCastInComparison` into `Simplifier` (#15012)
eb4ee6272c is described below

commit eb4ee6272c77a2724c75edc714a93a1dd3e2c13d
Author: Jay Zhan <[email protected]>
AuthorDate: Thu Mar 6 19:35:20 2025 +0800

    Move `UnwrapCastInComparison` into `Simplifier` (#15012)
    
    * add unwrap in simplify expr
    
    * rm unwrap cast
    
    * return err
    
    * rename
    
    * fix
    
    * fmt
    
    * add unwrap_cast module to simplify expressions
    
    * tweak comment
    
    * Move tests
    
    * Rewrite to use simplifier schema
    
    * Update tests for simplify logic
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/tests/sql/explain_analyze.rs       |   7 +-
 datafusion/optimizer/src/lib.rs                    |   1 -
 datafusion/optimizer/src/optimizer.rs              |   3 -
 .../src/simplify_expressions/expr_simplifier.rs    |  92 +++-
 .../optimizer/src/simplify_expressions/mod.rs      |   1 +
 .../unwrap_cast.rs}                                | 461 +++++++++------------
 datafusion/sqllogictest/test_files/explain.slt     |   4 -
 7 files changed, 291 insertions(+), 278 deletions(-)

diff --git a/datafusion/core/tests/sql/explain_analyze.rs 
b/datafusion/core/tests/sql/explain_analyze.rs
index 3bdc71a8eb..e8ef34c2af 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -355,7 +355,8 @@ async fn csv_explain_verbose() {
 async fn csv_explain_inlist_verbose() {
     let ctx = SessionContext::new();
     register_aggregate_csv_by_sql(&ctx).await;
-    let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 in 
(1,2,4)";
+    // IN lists with <= 3 items are rewritten to an OR chain, so test with 4 items
+    let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 in 
(1,2,4,5)";
     let actual = execute(&ctx, sql).await;
 
     // Optimized by PreCastLitInComparisonExpressions rule
@@ -368,12 +369,12 @@ async fn csv_explain_inlist_verbose() {
     // before optimization (Int64 literals)
     assert_contains!(
         &actual,
-        "aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4)])"
+        "aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4), Int64(5)])"
     );
     // after optimization (casted to Int8)
     assert_contains!(
         &actual,
-        "aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4)])"
+        "aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4), Int8(5)])"
     );
 }
 
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 61ca9b31cd..1280bf2f46 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -60,7 +60,6 @@ pub mod replace_distinct_aggregate;
 pub mod scalar_subquery_to_join;
 pub mod simplify_expressions;
 pub mod single_distinct_to_groupby;
-pub mod unwrap_cast_in_comparison;
 pub mod utils;
 
 #[cfg(test)]
diff --git a/datafusion/optimizer/src/optimizer.rs 
b/datafusion/optimizer/src/optimizer.rs
index 49bce3c1ce..018ad8ace0 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -54,7 +54,6 @@ use 
crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
 use crate::scalar_subquery_to_join::ScalarSubqueryToJoin;
 use crate::simplify_expressions::SimplifyExpressions;
 use crate::single_distinct_to_groupby::SingleDistinctToGroupBy;
-use crate::unwrap_cast_in_comparison::UnwrapCastInComparison;
 use crate::utils::log_plan;
 
 /// `OptimizerRule`s transforms one [`LogicalPlan`] into another which
@@ -243,7 +242,6 @@ impl Optimizer {
         let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
             Arc::new(EliminateNestedUnion::new()),
             Arc::new(SimplifyExpressions::new()),
-            Arc::new(UnwrapCastInComparison::new()),
             Arc::new(ReplaceDistinctWithAggregate::new()),
             Arc::new(EliminateJoin::new()),
             Arc::new(DecorrelatePredicateSubquery::new()),
@@ -266,7 +264,6 @@ impl Optimizer {
             // The previous optimizations added expressions and projections,
             // that might benefit from the following rules
             Arc::new(SimplifyExpressions::new()),
-            Arc::new(UnwrapCastInComparison::new()),
             Arc::new(CommonSubexprEliminate::new()),
             Arc::new(EliminateGroupByConstant::new()),
             Arc::new(OptimizeProjections::new()),
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 840c108905..d5a1b84e6a 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -32,7 +32,6 @@ use datafusion_common::{
     tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
 };
 use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, 
ScalarValue};
-use datafusion_expr::simplify::ExprSimplifyResult;
 use datafusion_expr::{
     and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, 
Volatility,
     WindowFunctionDefinition,
@@ -42,14 +41,23 @@ use datafusion_expr::{
     expr::{InList, InSubquery, WindowFunction},
     utils::{iter_conjunction, iter_conjunction_owned},
 };
+use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast};
 use datafusion_physical_expr::{create_physical_expr, 
execution_props::ExecutionProps};
 
 use super::inlist_simplifier::ShortenInListSimplifier;
 use super::utils::*;
-use crate::analyzer::type_coercion::TypeCoercionRewriter;
 use crate::simplify_expressions::guarantees::GuaranteeRewriter;
 use crate::simplify_expressions::regex::simplify_regex_expr;
+use crate::simplify_expressions::unwrap_cast::{
+    is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
+    is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
+    unwrap_cast_in_comparison_for_binary,
+};
 use crate::simplify_expressions::SimplifyInfo;
+use crate::{
+    analyzer::type_coercion::TypeCoercionRewriter,
+    simplify_expressions::unwrap_cast::try_cast_literal_to_type,
+};
 use indexmap::IndexSet;
 use regex::Regex;
 
@@ -1742,6 +1750,86 @@ impl<S: SimplifyInfo> TreeNodeRewriter for 
Simplifier<'_, S> {
                 }
             }
 
+            // =======================================
+            // unwrap_cast_in_comparison
+            // =======================================
+            //
+            // For case:
+            // try_cast/cast(expr as data_type) op literal
+            Expr::BinaryExpr(BinaryExpr { left, op, right })
+                if 
is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
+                    info, &left, &right,
+                ) && op.supports_propagation() =>
+            {
+                unwrap_cast_in_comparison_for_binary(info, left, right, op)?
+            }
+            // literal op try_cast/cast(expr as data_type)
+            // -->
+            // try_cast/cast(expr as data_type) op_swap literal
+            Expr::BinaryExpr(BinaryExpr { left, op, right })
+                if 
is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
+                    info, &right, &left,
+                ) && op.supports_propagation()
+                    && op.swap().is_some() =>
+            {
+                unwrap_cast_in_comparison_for_binary(
+                    info,
+                    right,
+                    left,
+                    op.swap().unwrap(),
+                )?
+            }
+            // For case:
+            // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
+            Expr::InList(InList {
+                expr: mut left,
+                list,
+                negated,
+            }) if 
is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
+                info, &left, &list,
+            ) =>
+            {
+                let (Expr::TryCast(TryCast {
+                    expr: left_expr, ..
+                })
+                | Expr::Cast(Cast {
+                    expr: left_expr, ..
+                })) = left.as_mut()
+                else {
+                    return internal_err!("Expect cast expr, but got {:?}", 
left)?;
+                };
+
+                let expr_type = info.get_data_type(left_expr)?;
+                let right_exprs = list
+                    .into_iter()
+                    .map(|right| {
+                        match right {
+                            Expr::Literal(right_lit_value) => {
+                                // If the literal can be cast to the type of the inner (uncast) expr,
+                                // unwrap the cast/try_cast and apply the cast to the literal instead.
+                                let Some(value) = 
try_cast_literal_to_type(&right_lit_value, &expr_type) else {
+                                    internal_err!(
+                                        "Can't cast the list expr {:?} to type 
{:?}",
+                                        right_lit_value, &expr_type
+                                    )?
+                                };
+                                Ok(lit(value))
+                            }
+                            other_expr => internal_err!(
+                                "Only support literal expr to optimize, but 
the expr is {:?}",
+                                &other_expr
+                            ),
+                        }
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                Transformed::yes(Expr::InList(InList {
+                    expr: std::mem::take(left_expr),
+                    list: right_exprs,
+                    negated,
+                }))
+            }
+
             // no additional rewrites possible
             expr => Transformed::no(expr),
         })
diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs 
b/datafusion/optimizer/src/simplify_expressions/mod.rs
index 46c066c11c..5fbee02e39 100644
--- a/datafusion/optimizer/src/simplify_expressions/mod.rs
+++ b/datafusion/optimizer/src/simplify_expressions/mod.rs
@@ -23,6 +23,7 @@ mod guarantees;
 mod inlist_simplifier;
 mod regex;
 pub mod simplify_exprs;
+mod unwrap_cast;
 mod utils;
 
 // backwards compatibility
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs 
b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
similarity index 79%
rename from datafusion/optimizer/src/unwrap_cast_in_comparison.rs
rename to datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
index e2b8a966cb..7670bdf98b 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
@@ -15,274 +15,176 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)`
+//! Unwrap casts in binary comparisons
+//!
+//! The functions in this module attempt to remove casts from
+//! comparisons to literals ([`ScalarValue`]s) by applying the casts
+//! to the literals if possible. It is inspired by the optimizer rule
+//! `UnwrapCastInBinaryComparison` of Spark.
+//!
+//! Removing casts often improves performance because:
+//! 1. The cast is done once (to the literal) rather than to every value
+//! 2. Can enable other optimizations such as predicate pushdown that
+//!    don't support casting
+//!
+//! The rule is applied to expressions of the following forms:
+//!
+//! 1. `cast(left_expr as data_type) comparison_op literal_expr`
+//! 2. `literal_expr comparison_op cast(left_expr as data_type)`
+//! 3. `cast(left_expr as data_type) IN (literal1, literal2, ...)`
+//! 4. `literal_expr IN (cast(expr1) , cast(expr2), ...)`
+//!
+//! If the expression matches one of the forms above, the rule checks
+//! that the value of `literal` is within the (min, max) range of the
+//! expr's data_type. If it is, the literal is cast to the data type
+//! of the expr on the other side, and the cast is removed from that
+//! side.
+//!
+//! # Example
+//!
+//! Suppose the DataType of c1 is INT32. Given the filter
+//!
+//! ```text
+//! cast(c1 as INT64) > INT64(10)
+//! ```
+//!
+//! This rule will remove the cast and rewrite the expression to:
+//!
+//! ```text
+//! c1 > INT32(10)
+//! ```
+//!
 
 use std::cmp::Ordering;
-use std::mem;
-use std::sync::Arc;
 
-use crate::optimizer::ApplyOrder;
-use crate::{OptimizerConfig, OptimizerRule};
-
-use crate::utils::NamePreserver;
 use arrow::datatypes::{
     DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
     MIN_DECIMAL128_FOR_EACH_PRECISION,
 };
 use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
-use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
-use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, 
ScalarValue};
-use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast};
-use datafusion_expr::utils::merge_schema;
-use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan};
-
-/// [`UnwrapCastInComparison`] attempts to remove casts from
-/// comparisons to literals ([`ScalarValue`]s) by applying the casts
-/// to the literals if possible. It is inspired by the optimizer rule
-/// `UnwrapCastInBinaryComparison` of Spark.
-///
-/// Removing casts often improves performance because:
-/// 1. The cast is done once (to the literal) rather than to every value
-/// 2. Can enable other optimizations such as predicate pushdown that
-///    don't support casting
-///
-/// The rule is applied to expressions of the following forms:
-///
-/// 1. `cast(left_expr as data_type) comparison_op literal_expr`
-/// 2. `literal_expr comparison_op cast(left_expr as data_type)`
-/// 3. `cast(literal_expr) IN (expr1, expr2, ...)`
-/// 4. `literal_expr IN (cast(expr1) , cast(expr2), ...)`
-///
-/// If the expression matches one of the forms above, the rule will
-/// ensure the value of `literal` is in range(min, max) of the
-/// expr's data_type, and if the scalar is within range, the literal
-/// will be casted to the data type of expr on the other side, and the
-/// cast will be removed from the other side.
-///
-/// # Example
-///
-/// If the DataType of c1 is INT32. Given the filter
-///
-/// ```text
-/// Filter: cast(c1 as INT64) > INT64(10)`
-/// ```
-///
-/// This rule will remove the cast and rewrite the expression to:
-///
-/// ```text
-/// Filter: c1 > INT32(10)
-/// ```
-///
-#[derive(Default, Debug)]
-pub struct UnwrapCastInComparison {}
-
-impl UnwrapCastInComparison {
-    pub fn new() -> Self {
-        Self::default()
+use datafusion_common::{internal_err, tree_node::Transformed};
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{lit, BinaryExpr};
+use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast};
+
+pub(super) fn unwrap_cast_in_comparison_for_binary<S: SimplifyInfo>(
+    info: &S,
+    cast_expr: Box<Expr>,
+    literal: Box<Expr>,
+    op: Operator,
+) -> Result<Transformed<Expr>> {
+    match (*cast_expr, *literal) {
+        (
+            Expr::TryCast(TryCast { expr, .. }) | Expr::Cast(Cast { expr, .. 
}),
+            Expr::Literal(lit_value),
+        ) => {
+            let Ok(expr_type) = info.get_data_type(&expr) else {
+                return internal_err!("Can't get the data type of the expr 
{:?}", &expr);
+            };
+            // If the literal can be cast to the type of the inner (uncast) expr,
+            // unwrap the cast/try_cast and apply the cast to the literal instead.
+            let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) 
else {
+                return internal_err!(
+                    "Can't cast the literal expr {:?} to type {:?}",
+                    &lit_value,
+                    &expr_type
+                );
+            };
+            Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
+                left: expr,
+                op,
+                right: Box::new(lit(value)),
+            })))
+        }
+        _ => internal_err!("Expect cast expr and literal"),
     }
 }
 
-impl OptimizerRule for UnwrapCastInComparison {
-    fn name(&self) -> &str {
-        "unwrap_cast_in_comparison"
-    }
-
-    fn apply_order(&self) -> Option<ApplyOrder> {
-        Some(ApplyOrder::BottomUp)
-    }
+pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
+    S: SimplifyInfo,
+>(
+    info: &S,
+    expr: &Expr,
+    literal: &Expr,
+) -> bool {
+    match (expr, literal) {
+        (
+            Expr::TryCast(TryCast {
+                expr: left_expr, ..
+            })
+            | Expr::Cast(Cast {
+                expr: left_expr, ..
+            }),
+            Expr::Literal(lit_val),
+        ) => {
+            let Ok(expr_type) = info.get_data_type(left_expr) else {
+                return false;
+            };
 
-    fn supports_rewrite(&self) -> bool {
-        true
-    }
+            let Ok(lit_type) = info.get_data_type(literal) else {
+                return false;
+            };
 
-    fn rewrite(
-        &self,
-        plan: LogicalPlan,
-        _config: &dyn OptimizerConfig,
-    ) -> Result<Transformed<LogicalPlan>> {
-        let mut schema = merge_schema(&plan.inputs());
-
-        if let LogicalPlan::TableScan(ts) = &plan {
-            let source_schema = DFSchema::try_from_qualified_schema(
-                ts.table_name.clone(),
-                &ts.source.schema(),
-            )?;
-            schema.merge(&source_schema);
+            try_cast_literal_to_type(lit_val, &expr_type).is_some()
+                && is_supported_type(&expr_type)
+                && is_supported_type(&lit_type)
         }
+        _ => false,
+    }
+}
 
-        schema.merge(plan.schema());
+pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist<
+    S: SimplifyInfo,
+>(
+    info: &S,
+    expr: &Expr,
+    list: &[Expr],
+) -> bool {
+    let (Expr::TryCast(TryCast {
+        expr: left_expr, ..
+    })
+    | Expr::Cast(Cast {
+        expr: left_expr, ..
+    })) = expr
+    else {
+        return false;
+    };
 
-        let mut expr_rewriter = UnwrapCastExprRewriter {
-            schema: Arc::new(schema),
-        };
+    let Ok(expr_type) = info.get_data_type(left_expr) else {
+        return false;
+    };
 
-        let name_preserver = NamePreserver::new(&plan);
-        plan.map_expressions(|expr| {
-            let original_name = name_preserver.save(&expr);
-            expr.rewrite(&mut expr_rewriter)
-                .map(|transformed| transformed.update_data(|e| 
original_name.restore(e)))
-        })
+    if !is_supported_type(&expr_type) {
+        return false;
     }
-}
 
-struct UnwrapCastExprRewriter {
-    schema: DFSchemaRef,
-}
+    for right in list {
+        let Ok(right_type) = info.get_data_type(right) else {
+            return false;
+        };
 
-impl TreeNodeRewriter for UnwrapCastExprRewriter {
-    type Node = Expr;
-
-    fn f_up(&mut self, mut expr: Expr) -> Result<Transformed<Expr>> {
-        match &mut expr {
-            // For case:
-            // try_cast/cast(expr as data_type) op literal
-            // literal op try_cast/cast(expr as data_type)
-            Expr::BinaryExpr(BinaryExpr { left, op, right })
-                if {
-                    let Ok(left_type) = left.get_type(&self.schema) else {
-                        return Ok(Transformed::no(expr));
-                    };
-                    let Ok(right_type) = right.get_type(&self.schema) else {
-                        return Ok(Transformed::no(expr));
-                    };
-                    is_supported_type(&left_type)
-                        && is_supported_type(&right_type)
-                        && op.supports_propagation()
-                } =>
-            {
-                match (left.as_mut(), right.as_mut()) {
-                    (
-                        Expr::Literal(left_lit_value),
-                        Expr::TryCast(TryCast {
-                            expr: right_expr, ..
-                        })
-                        | Expr::Cast(Cast {
-                            expr: right_expr, ..
-                        }),
-                    ) => {
-                        // if the left_lit_value can be cast to the type of 
expr
-                        // we need to unwrap the cast for cast/try_cast expr, 
and add cast to the literal
-                        let Ok(expr_type) = right_expr.get_type(&self.schema) 
else {
-                            return Ok(Transformed::no(expr));
-                        };
-                        match expr_type {
-                            // 
https://github.com/apache/datafusion/issues/12180
-                            DataType::Utf8View => Ok(Transformed::no(expr)),
-                            _ => {
-                                let Some(value) =
-                                    try_cast_literal_to_type(left_lit_value, 
&expr_type)
-                                else {
-                                    return Ok(Transformed::no(expr));
-                                };
-                                **left = lit(value);
-                                // unwrap the cast/try_cast for the right expr
-                                **right = mem::take(right_expr);
-                                Ok(Transformed::yes(expr))
-                            }
-                        }
-                    }
-                    (
-                        Expr::TryCast(TryCast {
-                            expr: left_expr, ..
-                        })
-                        | Expr::Cast(Cast {
-                            expr: left_expr, ..
-                        }),
-                        Expr::Literal(right_lit_value),
-                    ) => {
-                        // if the right_lit_value can be cast to the type of 
expr
-                        // we need to unwrap the cast for cast/try_cast expr, 
and add cast to the literal
-                        let Ok(expr_type) = left_expr.get_type(&self.schema) 
else {
-                            return Ok(Transformed::no(expr));
-                        };
-                        match expr_type {
-                            // 
https://github.com/apache/datafusion/issues/12180
-                            DataType::Utf8View => Ok(Transformed::no(expr)),
-                            _ => {
-                                let Some(value) =
-                                    try_cast_literal_to_type(right_lit_value, 
&expr_type)
-                                else {
-                                    return Ok(Transformed::no(expr));
-                                };
-                                // unwrap the cast/try_cast for the left expr
-                                **left = mem::take(left_expr);
-                                **right = lit(value);
-                                Ok(Transformed::yes(expr))
-                            }
-                        }
-                    }
-                    _ => Ok(Transformed::no(expr)),
-                }
-            }
-            // For case:
-            // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
-            Expr::InList(InList {
-                expr: left, list, ..
-            }) => {
-                let (Expr::TryCast(TryCast {
-                    expr: left_expr, ..
-                })
-                | Expr::Cast(Cast {
-                    expr: left_expr, ..
-                })) = left.as_mut()
-                else {
-                    return Ok(Transformed::no(expr));
-                };
-                let Ok(expr_type) = left_expr.get_type(&self.schema) else {
-                    return Ok(Transformed::no(expr));
-                };
-                if !is_supported_type(&expr_type) {
-                    return Ok(Transformed::no(expr));
-                }
-                let Ok(right_exprs) = list
-                    .iter()
-                    .map(|right| {
-                        let right_type = right.get_type(&self.schema)?;
-                        if !is_supported_type(&right_type) {
-                            internal_err!(
-                                "The type of list expr {} is not supported",
-                                &right_type
-                            )?;
-                        }
-                        match right {
-                            Expr::Literal(right_lit_value) => {
-                                // if the right_lit_value can be casted to the 
type of internal_left_expr
-                                // we need to unwrap the cast for 
cast/try_cast expr, and add cast to the literal
-                                let Some(value) = 
try_cast_literal_to_type(right_lit_value, &expr_type) else {
-                                    internal_err!(
-                                        "Can't cast the list expr {:?} to type 
{:?}",
-                                        right_lit_value, &expr_type
-                                    )?
-                                };
-                                Ok(lit(value))
-                            }
-                            other_expr => internal_err!(
-                                "Only support literal expr to optimize, but 
the expr is {:?}",
-                                &other_expr
-                            ),
-                        }
-                    })
-                    .collect::<Result<Vec<_>>>() else {
-                    return Ok(Transformed::no(expr))
-                };
-                **left = mem::take(left_expr);
-                *list = right_exprs;
-                Ok(Transformed::yes(expr))
-            }
-            // TODO: handle other expr type and dfs visit them
-            _ => Ok(Transformed::no(expr)),
+        if !is_supported_type(&right_type) {
+            return false;
+        }
+
+        match right {
+            Expr::Literal(lit_val)
+                if try_cast_literal_to_type(lit_val, &expr_type).is_some() => 
{}
+            _ => return false,
         }
     }
+
+    true
 }
 
-/// Returns true if [UnwrapCastExprRewriter] supports this data type
+/// Returns true if unwrap_cast_in_comparison supports this data type
 fn is_supported_type(data_type: &DataType) -> bool {
     is_supported_numeric_type(data_type)
         || is_supported_string_type(data_type)
         || is_supported_dictionary_type(data_type)
 }
 
-/// Returns true if [[UnwrapCastExprRewriter]] support this numeric type
+/// Returns true if unwrap_cast_in_comparison supports this numeric type
 fn is_supported_numeric_type(data_type: &DataType) -> bool {
     matches!(
         data_type,
@@ -299,7 +201,7 @@ fn is_supported_numeric_type(data_type: &DataType) -> bool {
     )
 }
 
-/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a 
string
+/// Returns true if unwrap_cast_in_comparison supports casting this value as a 
string
 fn is_supported_string_type(data_type: &DataType) -> bool {
     matches!(
         data_type,
@@ -307,14 +209,14 @@ fn is_supported_string_type(data_type: &DataType) -> bool 
{
     )
 }
 
-/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a 
dictionary
+/// Returns true if unwrap_cast_in_comparison supports casting this value as a 
dictionary
 fn is_supported_dictionary_type(data_type: &DataType) -> bool {
     matches!(data_type,
                     DataType::Dictionary(_, inner) if is_supported_type(inner))
 }
 
 /// Convert a literal value from one data type to another
-fn try_cast_literal_to_type(
+pub(super) fn try_cast_literal_to_type(
     lit_value: &ScalarValue,
     target_type: &DataType,
 ) -> Option<ScalarValue> {
@@ -540,13 +442,16 @@ fn cast_between_timestamp(from: &DataType, to: &DataType, 
value: i128) -> Option
 
 #[cfg(test)]
 mod tests {
-    use std::collections::HashMap;
-
     use super::*;
+    use std::collections::HashMap;
+    use std::sync::Arc;
 
+    use crate::simplify_expressions::ExprSimplifier;
     use arrow::compute::{cast_with_options, CastOptions};
     use arrow::datatypes::Field;
-    use datafusion_common::tree_node::TransformedResult;
+    use datafusion_common::{DFSchema, DFSchemaRef};
+    use datafusion_expr::execution_props::ExecutionProps;
+    use datafusion_expr::simplify::SimplifyContext;
     use datafusion_expr::{cast, col, in_list, try_cast};
 
     #[test]
@@ -587,9 +492,9 @@ mod tests {
         let expected = col("c1").lt(null_i32());
         assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
 
-        // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12)
+        // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) => 
BOOL(NULL)
         let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32));
-        let expected = null_i8().lt(lit(12i8));
+        let expected = null_bool();
         assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
     }
 
@@ -623,7 +528,7 @@ mod tests {
         // Verify reversed argument order
         // arrow_cast('value', 'Dictionary<Int32, Utf8>') = cast(str1 as 
Dictionary<Int32, Utf8>) => Utf8('value1') = str1
         let expr_input = lit(dict.clone()).eq(cast(col("str1"), 
dict.data_type()));
-        let expected = lit("value").eq(col("str1"));
+        let expected = col("str1").eq(lit("value"));
         assert_eq!(optimize_test(expr_input, &schema), expected);
     }
 
@@ -740,15 +645,27 @@ mod tests {
     #[test]
     fn test_unwrap_list_cast_comparison() {
         let schema = expr_test_schema();
-        // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN 
(INT32(12),INT32(24))
-        let expr_lt =
-            cast(col("c1"), DataType::Int64).in_list(vec![lit(12i64), 
lit(24i64)], false);
-        let expected = col("c1").in_list(vec![lit(12i32), lit(24i32)], false);
+        // cast(c1 as INT64) IN (INT64(12),INT64(23),INT64(34),INT64(56),INT64(78)) ->
+        // c1 IN (INT32(12),INT32(23),INT32(34),INT32(56),INT32(78))
+        let expr_lt = cast(col("c1"), DataType::Int64).in_list(
+            vec![lit(12i64), lit(23i64), lit(34i64), lit(56i64), lit(78i64)],
+            false,
+        );
+        let expected = col("c1").in_list(
+            vec![lit(12i32), lit(23i32), lit(34i32), lit(56i32), lit(78i32)],
+            false,
+        );
         assert_eq!(optimize_test(expr_lt, &schema), expected);
-        // INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN 
(INT32(12),INT32(24))
-        let expr_lt =
-            cast(col("c2"), DataType::Int32).in_list(vec![null_i32(), 
lit(14i32)], false);
-        let expected = col("c2").in_list(vec![null_i64(), lit(14i64)], false);
+        // cast(c2 as INT32) IN (INT32(NULL),INT32(24),INT64(34),INT64(56),INT64(78)) ->
+        // c2 IN (INT64(NULL),INT64(24),INT64(34),INT64(56),INT64(78))
+        let expr_lt = cast(col("c2"), DataType::Int32).in_list(
+            vec![null_i32(), lit(24i32), lit(34i64), lit(56i64), lit(78i64)],
+            false,
+        );
+        let expected = col("c2").in_list(
+            vec![null_i64(), lit(24i64), lit(34i64), lit(56i64), lit(78i64)],
+            false,
+        );
 
         assert_eq!(optimize_test(expr_lt, &schema), expected);
 
@@ -774,10 +691,14 @@ mod tests {
         );
         assert_eq!(optimize_test(expr_lt, &schema), expected);
 
-        // cast(INT32(12), INT64) IN (.....)
-        let expr_lt = cast(lit(12i32), DataType::Int64)
-            .in_list(vec![lit(13i64), lit(12i64)], false);
-        let expected = lit(12i32).in_list(vec![lit(13i32), lit(12i32)], false);
+        // cast(INT32(12), INT64) IN (.....) =>
+        // INT64(12) IN (INT64(12),INT64(13),INT64(14),INT64(15),INT64(16))
+        // => true
+        let expr_lt = cast(lit(12i32), DataType::Int64).in_list(
+            vec![lit(12i64), lit(13i64), lit(14i64), lit(15i64), lit(16i64)],
+            false,
+        );
+        let expected = lit(true);
         assert_eq!(optimize_test(expr_lt, &schema), expected);
     }
 
@@ -815,8 +736,12 @@ mod tests {
         assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
 
         // inlist for unsupported data type
-        let expr_input =
-            in_list(cast(col("c6"), DataType::Float64), vec![lit(0f64)], false);
+        let expr_input = in_list(
+            cast(col("c6"), DataType::Float64),
+            // need more literals to avoid rewriting to binary expr
+            vec![lit(0f64), lit(1f64), lit(2f64), lit(3f64), lit(4f64)],
+            false,
+        );
         assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
     }
 
@@ -833,10 +758,12 @@ mod tests {
     }
 
     fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
-        let mut expr_rewriter = UnwrapCastExprRewriter {
-            schema: Arc::clone(schema),
-        };
-        expr.rewrite(&mut expr_rewriter).data().unwrap()
+        let props = ExecutionProps::new();
+        let simplifier = ExprSimplifier::new(
+            SimplifyContext::new(&props).with_schema(Arc::clone(schema)),
+        );
+
+        simplifier.simplify(expr).unwrap()
     }
 
     fn expr_test_schema() -> DFSchemaRef {
@@ -862,6 +789,10 @@ mod tests {
         )
     }
 
+    fn null_bool() -> Expr {
+        lit(ScalarValue::Boolean(None))
+    }
+
     fn null_i8() -> Expr {
         lit(ScalarValue::Int8(None))
     }
diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt
index d32ddd1512..cab7308f6f 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -181,7 +181,6 @@ logical_plan after type_coercion SAME TEXT AS ABOVE
 analyzed_logical_plan SAME TEXT AS ABOVE
 logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
 logical_plan after simplify_expressions SAME TEXT AS ABOVE
-logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
 logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
 logical_plan after eliminate_join SAME TEXT AS ABOVE
 logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE
@@ -200,13 +199,11 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE
 logical_plan after push_down_filter SAME TEXT AS ABOVE
 logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE
 logical_plan after simplify_expressions SAME TEXT AS ABOVE
-logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
 logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE
 logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE
 logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c]
 logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
 logical_plan after simplify_expressions SAME TEXT AS ABOVE
-logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
 logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
 logical_plan after eliminate_join SAME TEXT AS ABOVE
 logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE
@@ -225,7 +222,6 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE
 logical_plan after push_down_filter SAME TEXT AS ABOVE
 logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE
 logical_plan after simplify_expressions SAME TEXT AS ABOVE
-logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
 logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE
 logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE
 logical_plan after optimize_projections SAME TEXT AS ABOVE


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to