This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new a1b2112c2 change pre_cast_lit_in_comparison to 
unwrap_cast_in_comparison (#3662)
a1b2112c2 is described below

commit a1b2112c219d4ec67a153de41ea894c68736e9f6
Author: Kun Liu <[email protected]>
AuthorDate: Fri Sep 30 22:35:50 2022 +0800

    change pre_cast_lit_in_comparison to unwrap_cast_in_comparison (#3662)
    
    * change pre_cast_lit_in_comparison to unwrap_cast_in_comparison
    
    * change some test case
---
 datafusion/core/src/execution/context.rs           |   4 +-
 datafusion/core/tests/sql/explain_analyze.rs       |  20 +-
 datafusion/optimizer/src/lib.rs                    |   2 +-
 ..._comparison.rs => unwrap_cast_in_comparison.rs} | 340 ++++++++++++---------
 datafusion/optimizer/tests/integration-test.rs     |   4 +-
 5 files changed, 208 insertions(+), 162 deletions(-)

diff --git a/datafusion/core/src/execution/context.rs 
b/datafusion/core/src/execution/context.rs
index ff0ccf835..2a805a5fc 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -110,10 +110,10 @@ use datafusion_expr::{TableSource, TableType};
 use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
 use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
 use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
-use 
datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
 use 
datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
 use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
 use datafusion_optimizer::type_coercion::TypeCoercion;
+use datafusion_optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison;
 use datafusion_sql::{
     parser::DFParser,
     planner::{ContextProvider, SqlToRel},
@@ -1466,9 +1466,9 @@ impl SessionState {
         }
 
         let mut rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
-            Arc::new(PreCastLitInComparisonExpressions::new()),
             Arc::new(TypeCoercion::new()),
             Arc::new(SimplifyExpressions::new()),
+            Arc::new(UnwrapCastInComparison::new()),
             Arc::new(DecorrelateWhereExists::new()),
             Arc::new(DecorrelateWhereIn::new()),
             Arc::new(ScalarSubqueryToJoin::new()),
diff --git a/datafusion/core/tests/sql/explain_analyze.rs 
b/datafusion/core/tests/sql/explain_analyze.rs
index fe51aedc8..7d09d9483 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -767,8 +767,6 @@ async fn test_physical_plan_display_indent_multi_children() 
{
 #[tokio::test]
 #[cfg_attr(tarpaulin, ignore)]
 async fn csv_explain() {
-    // TODO: https://github.com/apache/arrow-datafusion/issues/3622 refactor 
the `PreCastLitInComparisonExpressions`
-
     // This test uses the execute function that create full plan cycle: 
logical, optimized logical, and physical,
     // then execute the physical plan and return the final explain results
     let ctx = SessionContext::new();
@@ -779,23 +777,6 @@ async fn csv_explain() {
 
     // Note can't use `assert_batches_eq` as the plan needs to be
     // normalized for filenames and number of cores
-    let expected = vec![
-        vec![
-            "logical_plan",
-            "Projection: #aggregate_test_100.c1\
-             \n  Filter: CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)\
-             \n    TableScan: aggregate_test_100 projection=[c1, c2], 
partial_filters=[CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)]"
-        ],
-        vec!["physical_plan",
-             "ProjectionExec: expr=[c1@0 as c1]\
-              \n  CoalesceBatchesExec: target_batch_size=4096\
-              \n    FilterExec: CAST(c2@1 AS Int32) > 10\
-              \n      RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\
-              \n        CsvExec: 
files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, 
limit=None, projection=[c1, c2]\
-              \n"
-        ]];
-    assert_eq!(expected, actual);
-
     let expected = vec![
         vec![
             "logical_plan",
@@ -811,6 +792,7 @@ async fn csv_explain() {
               \n        CsvExec: 
files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, 
limit=None, projection=[c1, c2]\
               \n"
         ]];
+    assert_eq!(expected, actual);
 
     let sql = "explain SELECT c1 FROM aggregate_test_100 where c2 > 10";
     let actual = execute(&ctx, sql).await;
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index bfb563436..879658c40 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -35,9 +35,9 @@ pub mod subquery_filter_to_join;
 pub mod type_coercion;
 pub mod utils;
 
-pub mod pre_cast_lit_in_comparison;
 pub mod rewrite_disjunctive_predicate;
 #[cfg(test)]
 pub mod test;
+pub mod unwrap_cast_in_comparison;
 
 pub use optimizer::{OptimizerConfig, OptimizerRule};
diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs 
b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
similarity index 60%
rename from datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
rename to datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index a6d915cf0..0d5665f29 100644
--- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -15,8 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Pre-cast literal binary comparison rule can be only used to the binary 
comparison expr.
-//! It can reduce adding the `Expr::Cast` to the expr instead of adding the 
`Expr::Cast` to literal expr.
+//! The unwrap-cast comparison rule currently applies to binary and in-list 
comparison exprs; other kinds
+//! of expr can be supported if needed.
+//! This rule removes the `Expr::Cast` from the non-literal expr and casts the 
literal expr instead.
 use crate::{OptimizerConfig, OptimizerRule};
 use arrow::datatypes::{
     DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
@@ -28,14 +29,14 @@ use datafusion_expr::{
     binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
 };
 
-/// The rule can be only used to the numeric binary comparison with literal 
expr, like below pattern:
-/// `left_expr comparison_op literal_expr` or `literal_expr comparison_op 
right_expr`.
-/// The data type of two sides must be signed numeric type now, and will 
support more data type later.
+/// The rule applies to numeric binary comparisons against a literal expr, 
matching the patterns below:
+/// `cast(left_expr as data_type) comparison_op literal_expr` or `literal_expr 
comparison_op cast(right_expr as data_type)`.
+/// The data types of the two sides must be equal and must currently be signed 
numeric types; more data types will be supported later.
 ///
 /// If the binary comparison expr match above rules, the optimizer will check 
if the value of `literal`
 /// is in within range(min,max) which is the range(min,max) of the data type 
for `left_expr` or `right_expr`.
 ///
-/// If this true, the literal expr will be casted to the data type of expr on 
the other side, and the result of
+/// If this is true, the literal expr will be cast to the data type of expr 
on the other side, and the result of
 /// binary comparison will be `left_expr comparison_op cast(literal_expr, 
left_data_type)` or
 /// `cast(literal_expr, right_data_type) comparison_op right_expr`. For better 
optimization,
 /// the expr of `cast(literal_expr, target_type)` will be precomputed and 
converted to the new expr `new_literal_expr`
@@ -45,19 +46,19 @@ use datafusion_expr::{
 /// This is inspired by the optimizer rule `UnwrapCastInBinaryComparison` of 
Spark.
 /// # Example
 ///
-/// `Filter: c1 > INT64(10)` will be optimized to `Filter: c1 > CAST(INT64(10) 
AS INT32),
+/// `Filter: cast(c1 as INT64) > INT64(10)` will be optimized to `Filter: c1 > 
CAST(INT64(10) AS INT32)`,
 /// and continue to be converted to `Filter: c1 > INT32(10)`, if the DataType 
of c1 is INT32.
 ///
 #[derive(Default)]
-pub struct PreCastLitInComparisonExpressions {}
+pub struct UnwrapCastInComparison {}
 
-impl PreCastLitInComparisonExpressions {
+impl UnwrapCastInComparison {
     pub fn new() -> Self {
         Self::default()
     }
 }
 
-impl OptimizerRule for PreCastLitInComparisonExpressions {
+impl OptimizerRule for UnwrapCastInComparison {
     fn optimize(
         &self,
         plan: &LogicalPlan,
@@ -67,7 +68,7 @@ impl OptimizerRule for PreCastLitInComparisonExpressions {
     }
 
     fn name(&self) -> &str {
-        "pre_cast_lit_in_comparison"
+        "unwrap_cast_in_comparison"
     }
 }
 
@@ -80,7 +81,7 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
 
     let schema = plan.schema();
 
-    let mut expr_rewriter = PreCastLitExprRewriter {
+    let mut expr_rewriter = UnwrapCastExprRewriter {
         schema: schema.clone(),
     };
 
@@ -93,17 +94,20 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
     from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
 }
 
-struct PreCastLitExprRewriter {
+struct UnwrapCastExprRewriter {
     schema: DFSchemaRef,
 }
 
-impl ExprRewriter for PreCastLitExprRewriter {
+impl ExprRewriter for UnwrapCastExprRewriter {
     fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
         Ok(RewriteRecursion::Continue)
     }
 
     fn mutate(&mut self, expr: Expr) -> Result<Expr> {
         match &expr {
+            // For case:
+            // try_cast/cast(expr as data_type) op literal
+            // literal op try_cast/cast(expr as data_type)
             Expr::BinaryExpr { left, op, right } => {
                 let left = left.as_ref().clone();
                 let right = right.as_ref().clone();
@@ -113,29 +117,48 @@ impl ExprRewriter for PreCastLitExprRewriter {
                 if left_type.is_err() || right_type.is_err() {
                     return Ok(expr.clone());
                 }
+                // Type coercion has already been applied to the plan, so the 
left and right types must be equal
                 let left_type = left_type?;
                 let right_type = right_type?;
-                if !left_type.eq(&right_type)
-                    && is_support_data_type(&left_type)
+                if is_support_data_type(&left_type)
                     && is_support_data_type(&right_type)
                     && is_comparison_op(op)
                 {
                     match (&left, &right) {
-                        (Expr::Literal(_), Expr::Literal(_)) => {
-                            // do nothing
-                        }
-                        (Expr::Literal(left_lit_value), _) => {
+                        (
+                            Expr::Literal(left_lit_value),
+                            Expr::TryCast { expr, .. } | Expr::Cast { expr, .. 
},
+                        ) => {
+                            // if the left_lit_value can be casted to the type 
of expr
+                            // we need to unwrap the cast for cast/try_cast 
expr, and add cast to the literal
+                            let expr_type = expr.get_type(&self.schema)?;
                             let casted_scalar_value =
-                                try_cast_literal_to_type(left_lit_value, 
&right_type)?;
+                                try_cast_literal_to_type(left_lit_value, 
&expr_type)?;
                             if let Some(value) = casted_scalar_value {
-                                return Ok(binary_expr(lit(value), *op, right));
+                                // unwrap the cast/try_cast for the right expr
+                                return Ok(binary_expr(
+                                    lit(value),
+                                    *op,
+                                    expr.as_ref().clone(),
+                                ));
                             }
                         }
-                        (_, Expr::Literal(right_lit_value)) => {
+                        (
+                            Expr::TryCast { expr, .. } | Expr::Cast { expr, .. 
},
+                            Expr::Literal(right_lit_value),
+                        ) => {
+                            // if the right_lit_value can be casted to the 
type of expr
+                            // we need to unwrap the cast for cast/try_cast 
expr, and add cast to the literal
+                            let expr_type = expr.get_type(&self.schema)?;
                             let casted_scalar_value =
-                                try_cast_literal_to_type(right_lit_value, 
&left_type)?;
+                                try_cast_literal_to_type(right_lit_value, 
&expr_type)?;
                             if let Some(value) = casted_scalar_value {
-                                return Ok(binary_expr(left, *op, lit(value)));
+                                // unwrap the cast/try_cast for the left expr
+                                return Ok(binary_expr(
+                                    expr.as_ref().clone(),
+                                    *op,
+                                    lit(value),
+                                ));
                             }
                         }
                         (_, _) => {
@@ -146,55 +169,75 @@ impl ExprRewriter for PreCastLitExprRewriter {
                 // return the new binary op
                 Ok(binary_expr(left, *op, right))
             }
+            // For case:
+            // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
             Expr::InList {
                 expr: left_expr,
                 list,
                 negated,
             } => {
-                let left = left_expr.as_ref().clone();
-                let left_type = left.get_type(&self.schema);
-                if left_type.is_err() {
-                    // error data type
-                    return Ok(expr);
-                }
-                let left_type = left_type?;
-                if !is_support_data_type(&left_type) {
-                    // not supported data type
-                    return Ok(expr);
-                }
-                let right_exprs = list
-                    .iter()
-                    .map(|right| {
-                        let right_type = right.get_type(&self.schema)?;
-                        if !is_support_data_type(&right_type) {
-                            return Err(DataFusionError::Internal(format!(
-                                "The type of list expr {} not support",
-                                &right_type
-                            )));
-                        }
-                        match right {
-                            Expr::Literal(right_lit_value) => {
-                                let casted_scalar_value =
-                                    try_cast_literal_to_type(right_lit_value, 
&left_type)?;
-                                if let Some(value) = casted_scalar_value {
-                                    Ok(lit(value))
-                                } else {
-                                    Err(DataFusionError::Internal(format!(
-                                        "Can't cast the list expr {:?} to type 
{:?}",
-                                        right_lit_value, &left_type
-                                    )))
+                if let Some(
+                    Expr::TryCast {
+                        expr: internal_left_expr,
+                        ..
+                    }
+                    | Expr::Cast {
+                        expr: internal_left_expr,
+                        ..
+                    },
+                ) = Some(left_expr.as_ref())
+                {
+                    let internal_left = internal_left_expr.as_ref().clone();
+                    let internal_left_type = 
internal_left.get_type(&self.schema);
+                    if internal_left_type.is_err() {
+                        // error data type
+                        return Ok(expr);
+                    }
+                    let internal_left_type = internal_left_type?;
+                    if !is_support_data_type(&internal_left_type) {
+                        // not supported data type
+                        return Ok(expr);
+                    }
+                    let right_exprs = list
+                        .iter()
+                        .map(|right| {
+                            let right_type = right.get_type(&self.schema)?;
+                            if !is_support_data_type(&right_type) {
+                                return Err(DataFusionError::Internal(format!(
+                                    "The type of list expr {} not support",
+                                    &right_type
+                                )));
+                            }
+                            match right {
+                                Expr::Literal(right_lit_value) => {
+                                    // if the right_lit_value can be casted to 
the type of internal_left_expr
+                                    // we need to unwrap the cast for 
cast/try_cast expr, and add cast to the literal
+                                    let casted_scalar_value =
+                                        
try_cast_literal_to_type(right_lit_value, &internal_left_type)?;
+                                    if let Some(value) = casted_scalar_value {
+                                        Ok(lit(value))
+                                    } else {
+                                        Err(DataFusionError::Internal(format!(
+                                            "Can't cast the list expr {:?} to 
type {:?}",
+                                            right_lit_value, 
&internal_left_type
+                                        )))
+                                    }
                                 }
+                                other_expr => 
Err(DataFusionError::Internal(format!(
+                                    "Only support literal expr to optimize, 
but the expr is {:?}",
+                                    &other_expr
+                                ))),
                             }
-                            other_expr => 
Err(DataFusionError::Internal(format!(
-                                "Only support literal expr to optimize, but 
the expr is {:?}",
-                                &other_expr
-                            ))),
+                        })
+                        .collect::<Result<Vec<_>>>();
+                    match right_exprs {
+                        Ok(right_exprs) => {
+                            Ok(in_list(internal_left, right_exprs, *negated))
                         }
-                    })
-                    .collect::<Result<Vec<_>>>();
-                match right_exprs {
-                    Ok(right_exprs) => Ok(in_list(left, right_exprs, 
*negated)),
-                    Err(_) => Ok(expr),
+                        Err(_) => Ok(expr),
+                    }
+                } else {
+                    Ok(expr)
                 }
             }
             // TODO: handle other expr type and dfs visit them
@@ -326,23 +369,19 @@ fn try_cast_literal_to_type(
 
 #[cfg(test)]
 mod tests {
-    use crate::pre_cast_lit_in_comparison::PreCastLitExprRewriter;
+    use crate::unwrap_cast_in_comparison::UnwrapCastExprRewriter;
     use arrow::datatypes::DataType;
     use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue};
     use datafusion_expr::expr_rewriter::ExprRewritable;
-    use datafusion_expr::{col, lit, Expr};
+    use datafusion_expr::{cast, col, lit, try_cast, Expr};
     use std::collections::HashMap;
     use std::sync::Arc;
 
     #[test]
-    fn test_not_cast_lit_comparison() {
+    fn test_not_unwrap_cast_comparison() {
         let schema = expr_test_schema();
-        // INT8(NULL) < INT32(12)
-        let lit_lt_lit =
-            lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int32(Some(12))));
-        assert_eq!(optimize_test(lit_lt_lit.clone(), &schema), lit_lt_lit);
-        // INT32(c1) > INT64(c2)
-        let c1_gt_c2 = col("c1").gt(col("c2"));
+        // cast(INT32(c1), INT64) > INT64(c2)
+        let c1_gt_c2 = cast(col("c1"), DataType::Int64).gt(col("c2"));
         assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2);
 
         // INT32(c1) < INT32(16), the type is same
@@ -350,110 +389,132 @@ mod tests {
         assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
 
         // the 99999999999 is not within the range of MAX(int32) and 
MIN(int32), we don't cast the lit(99999999999) to int32 type
-        let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(99999999999))));
+        let expr_lt = cast(col("c1"), DataType::Int64)
+            .lt(lit(ScalarValue::Int64(Some(99999999999))));
         assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
     }
 
     #[test]
-    fn test_pre_cast_lit_comparison() {
+    fn test_unwrap_cast_comparison() {
         let schema = expr_test_schema();
-        // c1 < INT64(16) -> c1 < cast(INT32(16))
+        // cast(c1, INT64) < INT64(16) -> INT32(c1) < cast(INT32(16))
         // the 16 is within the range of MAX(int32) and MIN(int32), we can 
cast the 16 to int32(16)
-        let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16))));
+        let expr_lt =
+            cast(col("c1"), 
DataType::Int64).lt(lit(ScalarValue::Int64(Some(16))));
+        let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
+        assert_eq!(optimize_test(expr_lt, &schema), expected);
+        let expr_lt =
+            try_cast(col("c1"), 
DataType::Int64).lt(lit(ScalarValue::Int64(Some(16))));
         let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
         assert_eq!(optimize_test(expr_lt, &schema), expected);
 
-        // INT64(c2) = INT32(16) => INT64(c2) = INT64(16)
-        let c2_eq_lit = col("c2").eq(lit(ScalarValue::Int32(Some(16))));
+        // cast(c2, INT32) = INT32(16) => INT64(c2) = INT64(16)
+        let c2_eq_lit =
+            cast(col("c2"), 
DataType::Int32).eq(lit(ScalarValue::Int32(Some(16))));
         let expected = col("c2").eq(lit(ScalarValue::Int64(Some(16))));
         assert_eq!(optimize_test(c2_eq_lit, &schema), expected);
 
-        // INT32(c1) < INT64(NULL) => INT32(c1) < INT32(NULL)
-        let c1_lt_lit_null = col("c1").lt(lit(ScalarValue::Int64(None)));
+        // cast(c1, INT64) < INT64(NULL) => INT32(c1) < INT32(NULL)
+        let c1_lt_lit_null =
+            cast(col("c1"), DataType::Int64).lt(lit(ScalarValue::Int64(None)));
         let expected = col("c1").lt(lit(ScalarValue::Int32(None)));
         assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
+
+        // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12)
+        let lit_lt_lit = cast(lit(ScalarValue::Int8(None)), DataType::Int32)
+            .lt(lit(ScalarValue::Int32(Some(12))));
+        let expected = 
lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int8(Some(12))));
+        assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
     }
 
     #[test]
-    fn test_not_cast_with_decimal_lit_comparison() {
+    fn test_not_unwrap_cast_with_decimal_comparison() {
         let schema = expr_test_schema();
         // integer to decimal: value is out of the bounds of the decimal
-        // c3 = INT64(100000000000000000)
-        let expr_eq = 
col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000))));
-        let expected = 
col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000))));
-        assert_eq!(optimize_test(expr_eq, &schema), expected);
-        // c4 = INT64(1000) will overflow the i128
-        let expr_eq = col("c4").eq(lit(ScalarValue::Int64(Some(1000))));
-        let expected = col("c4").eq(lit(ScalarValue::Int64(Some(1000))));
-        assert_eq!(optimize_test(expr_eq, &schema), expected);
+        // cast(c3, INT64) = INT64(100000000000000000)
+        let expr_eq = cast(col("c3"), DataType::Int64)
+            .eq(lit(ScalarValue::Int64(Some(100000000000000000))));
+        assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
+
+        // cast(c4, INT64) = INT64(1000) will overflow the i128
+        let expr_eq =
+            cast(col("c4"), 
DataType::Int64).eq(lit(ScalarValue::Int64(Some(1000))));
+        assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
 
         // decimal to decimal: value will lose the scale when convert to the 
target data type
         // c3 = DECIMAL(12340,20,4)
-        let expr_eq = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 
20, 4)));
-        let expected = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 
20, 4)));
-        assert_eq!(optimize_test(expr_eq, &schema), expected);
+        let expr_eq = cast(col("c3"), DataType::Decimal128(20, 4))
+            .eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4)));
+        assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
 
         // decimal to integer
         // c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to 
the target data type
-        let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 
1)));
-        let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 
1)));
-        assert_eq!(optimize_test(expr_eq, &schema), expected);
+        let expr_eq = cast(col("c1"), DataType::Decimal128(10, 1))
+            .eq(lit(ScalarValue::Decimal128(Some(123), 10, 1)));
+        assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
+
         // c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert 
to the target data type
-        let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 
2)));
-        let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 
10, 2)));
-        assert_eq!(optimize_test(expr_eq, &schema), expected);
+        let expr_eq = cast(col("c1"), DataType::Decimal128(10, 2))
+            .eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2)));
+        assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
     }
 
     #[test]
-    fn test_pre_cast_with_decimal_lit_comparison() {
+    fn test_unwrap_cast_with_decimal_lit_comparison() {
         let schema = expr_test_schema();
         // integer to decimal
         // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2));
-        let expr_lt = col("c3").lt(lit(ScalarValue::Int64(Some(16))));
+        let expr_lt =
+            try_cast(col("c3"), 
DataType::Int64).lt(lit(ScalarValue::Int64(Some(16))));
         let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(1600), 
18, 2)));
         assert_eq!(optimize_test(expr_lt, &schema), expected);
 
         // c3 < INT64(NULL)
-        let c1_lt_lit_null = col("c3").lt(lit(ScalarValue::Int64(None)));
+        let c1_lt_lit_null =
+            cast(col("c3"), DataType::Int64).lt(lit(ScalarValue::Int64(None)));
         let expected = col("c3").lt(lit(ScalarValue::Decimal128(None, 18, 2)));
         assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
 
         // decimal to decimal
         // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS 
DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2)
-        let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 10, 
0)));
+        let expr_lt = cast(col("c3"), DataType::Decimal128(10, 0))
+            .lt(lit(ScalarValue::Decimal128(Some(123), 10, 0)));
         let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(12300), 
18, 2)));
         assert_eq!(optimize_test(expr_lt, &schema), expected);
+
         // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS 
DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2)
-        let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(1230), 10, 
3)));
+        let expr_lt = cast(col("c3"), DataType::Decimal128(10, 3))
+            .lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3)));
         let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 18, 
2)));
         assert_eq!(optimize_test(expr_lt, &schema), expected);
 
         // decimal to integer
         // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS 
INT32) -> c1 < INT32(123)
-        let expr_lt = col("c1").lt(lit(ScalarValue::Decimal128(Some(12300), 
10, 2)));
+        let expr_lt = cast(col("c1"), DataType::Decimal128(10, 2))
+            .lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2)));
         let expected = col("c1").lt(lit(ScalarValue::Int32(Some(123))));
         assert_eq!(optimize_test(expr_lt, &schema), expected);
     }
 
     #[test]
-    fn test_not_list_cast_lit_comparison() {
+    fn test_not_unwrap_list_cast_lit_comparison() {
         let schema = expr_test_schema();
-        // left type is not supported
+        // internal left type is not supported
         // FLOAT32(C5) in ...
-        let expr_lt = col("c5").in_list(
+        let expr_lt = cast(col("c5"), DataType::Int64).in_list(
             vec![
                 lit(ScalarValue::Int64(Some(12))),
-                lit(ScalarValue::Int32(Some(12))),
+                lit(ScalarValue::Int64(Some(12))),
             ],
             false,
         );
         assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
 
-        // INT32(C1) in (FLOAT32(1.23), INT32(12), INT64(12))
-        let expr_lt = col("c1").in_list(
+        // cast(INT32(C1), Float32) in (FLOAT32(1.23), Float32(12), 
Float32(12))
+        let expr_lt = cast(col("c1"), DataType::Float32).in_list(
             vec![
-                lit(ScalarValue::Int32(Some(12))),
-                lit(ScalarValue::Int64(Some(12))),
+                lit(ScalarValue::Float32(Some(12.0))),
+                lit(ScalarValue::Float32(Some(12.0))),
                 lit(ScalarValue::Float32(Some(1.23))),
             ],
             false,
@@ -461,7 +522,7 @@ mod tests {
         assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
 
         // INT32(C1) in (INT64(99999999999), INT64(12))
-        let expr_lt = col("c1").in_list(
+        let expr_lt = cast(col("c1"), DataType::Int64).in_list(
             vec![
                 lit(ScalarValue::Int32(Some(12))),
                 lit(ScalarValue::Int64(Some(99999999999))),
@@ -471,10 +532,10 @@ mod tests {
         assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
 
         // DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3))
-        let expr_lt = col("c3").in_list(
+        let expr_lt = cast(col("c3"), DataType::Decimal128(12, 3)).in_list(
             vec![
-                lit(ScalarValue::Int32(Some(12))),
-                lit(ScalarValue::Int64(Some(12))),
+                lit(ScalarValue::Decimal128(Some(12), 12, 3)),
+                lit(ScalarValue::Decimal128(Some(12), 12, 3)),
                 lit(ScalarValue::Decimal128(Some(128), 12, 3)),
             ],
             false,
@@ -483,12 +544,12 @@ mod tests {
     }
 
     #[test]
-    fn test_pre_list_cast_lit_comparison() {
+    fn test_unwrap_list_cast_comparison() {
         let schema = expr_test_schema();
         // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN 
(INT32(12),INT32(24))
-        let expr_lt = col("c1").in_list(
+        let expr_lt = cast(col("c1"), DataType::Int64).in_list(
             vec![
-                lit(ScalarValue::Int32(Some(12))),
+                lit(ScalarValue::Int64(Some(12))),
                 lit(ScalarValue::Int64(Some(24))),
             ],
             false,
@@ -502,9 +563,9 @@ mod tests {
         );
         assert_eq!(optimize_test(expr_lt, &schema), expected);
         // INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN 
(INT32(12),INT32(24))
-        let expr_lt = col("c2").in_list(
+        let expr_lt = cast(col("c2"), DataType::Int32).in_list(
             vec![
-                lit(ScalarValue::Int64(None)),
+                lit(ScalarValue::Int32(None)),
                 lit(ScalarValue::Int32(Some(14))),
             ],
             false,
@@ -520,12 +581,13 @@ mod tests {
         assert_eq!(optimize_test(expr_lt, &schema), expected);
 
         // decimal test case
-        let expr_lt = col("c3").in_list(
+        // c3 is decimal(18,2)
+        let expr_lt = cast(col("c3"), DataType::Decimal128(19, 3)).in_list(
             vec![
-                lit(ScalarValue::Int32(Some(12))),
-                lit(ScalarValue::Int64(Some(24))),
-                lit(ScalarValue::Decimal128(Some(128), 10, 2)),
-                lit(ScalarValue::Decimal128(Some(1280), 10, 3)),
+                lit(ScalarValue::Decimal128(Some(12000), 19, 3)),
+                lit(ScalarValue::Decimal128(Some(24000), 19, 3)),
+                lit(ScalarValue::Decimal128(Some(1280), 19, 3)),
+                lit(ScalarValue::Decimal128(Some(1240), 19, 3)),
             ],
             false,
         );
@@ -534,23 +596,23 @@ mod tests {
                 lit(ScalarValue::Decimal128(Some(1200), 18, 2)),
                 lit(ScalarValue::Decimal128(Some(2400), 18, 2)),
                 lit(ScalarValue::Decimal128(Some(128), 18, 2)),
-                lit(ScalarValue::Decimal128(Some(128), 18, 2)),
+                lit(ScalarValue::Decimal128(Some(124), 18, 2)),
             ],
             false,
         );
         assert_eq!(optimize_test(expr_lt, &schema), expected);
 
-        // INT32(12) IN (.....)
-        let expr_lt = lit(ScalarValue::Int32(Some(12))).in_list(
+        // cast(INT32(12), INT64) IN (.....)
+        let expr_lt = cast(lit(ScalarValue::Int32(Some(12))), 
DataType::Int64).in_list(
             vec![
-                lit(ScalarValue::Int32(Some(12))),
+                lit(ScalarValue::Int64(Some(13))),
                 lit(ScalarValue::Int64(Some(12))),
             ],
             false,
         );
         let expected = lit(ScalarValue::Int32(Some(12))).in_list(
             vec![
-                lit(ScalarValue::Int32(Some(12))),
+                lit(ScalarValue::Int32(Some(13))),
                 lit(ScalarValue::Int32(Some(12))),
             ],
             false,
@@ -563,7 +625,9 @@ mod tests {
         let schema = expr_test_schema();
         // c1 < INT64(16) -> c1 < cast(INT32(16))
         // the 16 is within the range of MAX(int32) and MIN(int32), we can 
cast the 16 to int32(16)
-        let expr_lt = 
col("c1").lt(lit(ScalarValue::Int64(Some(16)))).alias("x");
+        let expr_lt = cast(col("c1"), DataType::Int64)
+            .lt(lit(ScalarValue::Int64(Some(16))))
+            .alias("x");
         let expected = 
col("c1").lt(lit(ScalarValue::Int32(Some(16)))).alias("x");
         assert_eq!(optimize_test(expr_lt, &schema), expected);
     }
@@ -573,9 +637,9 @@ mod tests {
         let schema = expr_test_schema();
         // c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32)
         // the 16 and 32 are within the range of MAX(int32) and MIN(int32), we 
can cast them to int32
-        let expr_lt = col("c1")
+        let expr_lt = cast(col("c1"), DataType::Int64)
             .lt(lit(ScalarValue::Int64(Some(16))))
-            .or(col("c1").gt(lit(ScalarValue::Int64(Some(32)))));
+            .or(cast(col("c1"), 
DataType::Int64).gt(lit(ScalarValue::Int64(Some(32)))));
         let expected = col("c1")
             .lt(lit(ScalarValue::Int32(Some(16))))
             .or(col("c1").gt(lit(ScalarValue::Int32(Some(32)))));
@@ -583,7 +647,7 @@ mod tests {
     }
 
     fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
-        let mut expr_rewriter = PreCastLitExprRewriter {
+        let mut expr_rewriter = UnwrapCastExprRewriter {
             schema: schema.clone(),
         };
         expr.rewrite(&mut expr_rewriter).unwrap()
diff --git a/datafusion/optimizer/tests/integration-test.rs 
b/datafusion/optimizer/tests/integration-test.rs
index 61bfafed7..7811e475c 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -27,7 +27,6 @@ use 
datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
 use datafusion_optimizer::filter_push_down::FilterPushDown;
 use datafusion_optimizer::limit_push_down::LimitPushDown;
 use datafusion_optimizer::optimizer::Optimizer;
-use 
datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
 use datafusion_optimizer::projection_push_down::ProjectionPushDown;
 use datafusion_optimizer::reduce_cross_join::ReduceCrossJoin;
 use datafusion_optimizer::reduce_outer_join::ReduceOuterJoin;
@@ -37,6 +36,7 @@ use 
datafusion_optimizer::simplify_expressions::SimplifyExpressions;
 use datafusion_optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
 use datafusion_optimizer::subquery_filter_to_join::SubqueryFilterToJoin;
 use datafusion_optimizer::type_coercion::TypeCoercion;
+use datafusion_optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison;
 use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
 use datafusion_sql::planner::{ContextProvider, SqlToRel};
 use datafusion_sql::sqlparser::ast::Statement;
@@ -107,9 +107,9 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
     // TODO should make align with rules in the context
     // https://github.com/apache/arrow-datafusion/issues/3524
     let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
-        Arc::new(PreCastLitInComparisonExpressions::new()),
         Arc::new(TypeCoercion::new()),
         Arc::new(SimplifyExpressions::new()),
+        Arc::new(UnwrapCastInComparison::new()),
         Arc::new(DecorrelateWhereExists::new()),
         Arc::new(DecorrelateWhereIn::new()),
         Arc::new(ScalarSubqueryToJoin::new()),


Reply via email to