This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 26c0c49ad1 perf: unwrap cast for comparing ints =/!= strings (#15110)
26c0c49ad1 is described below

commit 26c0c49ad1dd91e6cbb10b16d81faa3a6f361d81
Author: Li-Lun Lin <[email protected]>
AuthorDate: Fri Mar 28 03:41:33 2025 +0800

    perf: unwrap cast for comparing ints =/!= strings (#15110)
    
    * perf: unwrap cast for comparing ints =/!= strings
    
    * fix: update casting logic
    
    * test: add more unit test and new sqllogictest
    
    * Tweak slt tests
    
    * Revert "perf: unwrap cast for comparing ints =/!= strings"
    
    This reverts commit 808d6ab3ceb0281d055965a330b8ffb1c47fa65b.
    
    * fix: eliminate column cast and cast literal before coercion
    
    * fix: physical expr coercion test
    
    * feat: unwrap cast after round-trip cast verification
    
    * fix: unwrap cast on round-trip cast stable strings
    
    * revert: remove avoid cast changes
    
    * refactor: apply review suggestions
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/optimizer/src/analyzer/type_coercion.rs |   1 +
 .../src/simplify_expressions/expr_simplifier.rs    |   4 +-
 .../src/simplify_expressions/unwrap_cast.rs        | 104 +++++++++++++++++++++
 .../sqllogictest/test_files/push_down_filter.slt   |  64 +++++++++++++
 4 files changed, 171 insertions(+), 2 deletions(-)

diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 07eb795462..a77249424f 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -296,6 +296,7 @@ impl<'a> TypeCoercionRewriter<'a> {
             &right.get_type(right_schema)?,
         )
         .get_input_types()?;
+
         Ok((
             left.cast_to(&left_type, left_schema)?,
             right.cast_to(&right_type, right_schema)?,
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index ce10c7e5c6..9003467703 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -1758,7 +1758,7 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
             // try_cast/cast(expr as data_type) op literal
             Expr::BinaryExpr(BinaryExpr { left, op, right })
                 if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
-                    info, &left, &right,
+                    info, &left, op, &right,
                 ) && op.supports_propagation() =>
             {
                 unwrap_cast_in_comparison_for_binary(info, left, right, op)?
@@ -1768,7 +1768,7 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
             // try_cast/cast(expr as data_type) op_swap literal
             Expr::BinaryExpr(BinaryExpr { left, op, right })
                 if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
-                    info, &right, &left,
+                    info, &right, op, &left,
                 ) && op.supports_propagation()
                     && op.swap().is_some() =>
             {
diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
index 7670bdf98b..be71a8cd19 100644
--- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
+++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
@@ -81,6 +81,16 @@ pub(super) fn unwrap_cast_in_comparison_for_binary<S: SimplifyInfo>(
             let Ok(expr_type) = info.get_data_type(&expr) else {
                 return internal_err!("Can't get the data type of the expr {:?}", &expr);
             };
+
+            if let Some(value) = cast_literal_to_type_with_op(&lit_value, &expr_type, op)
+            {
+                return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
+                    left: expr,
+                    op,
+                    right: Box::new(lit(value)),
+                })));
+            };
+
             // if the lit_value can be casted to the type of internal_left_expr
             // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
             let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else {
@@ -105,6 +115,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
 >(
     info: &S,
     expr: &Expr,
+    op: Operator,
     literal: &Expr,
 ) -> bool {
     match (expr, literal) {
@@ -125,6 +136,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
                 return false;
             };
 
+            if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() {
+                return true;
+            }
+
             try_cast_literal_to_type(lit_val, &expr_type).is_some()
                 && is_supported_type(&expr_type)
                 && is_supported_type(&lit_type)
@@ -215,6 +230,52 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool {
                     DataType::Dictionary(_, inner) if is_supported_type(inner))
 }
 
+/// Tries to move a cast from an expression (such as a column) to the literal on the other side of a comparison operator.
+///
+/// Specifically, rewrites
+/// ```sql
+/// cast(col) <op> <literal>
+/// ```
+///
+/// To
+///
+/// ```sql
+/// col <op> cast(<literal>)
+/// col <op> <casted_literal>
+/// ```
+fn cast_literal_to_type_with_op(
+    lit_value: &ScalarValue,
+    target_type: &DataType,
+    op: Operator,
+) -> Option<ScalarValue> {
+    match (op, lit_value) {
+        (
+            Operator::Eq | Operator::NotEq,
+            ScalarValue::Utf8(Some(_))
+            | ScalarValue::Utf8View(Some(_))
+            | ScalarValue::LargeUtf8(Some(_)),
+        ) => {
+            // Only try for integer types (TODO can we do this for other types
+            // like timestamps)?
+            use DataType::*;
+            if matches!(
+                target_type,
+                Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
+            ) {
+                let casted = lit_value.cast_to(target_type).ok()?;
+                let round_tripped = casted.cast_to(&lit_value.data_type()).ok()?;
+                if lit_value != &round_tripped {
+                    return None;
+                }
+                Some(casted)
+            } else {
+                None
+            }
+        }
+        _ => None,
+    }
+}
+
 /// Convert a literal value from one data type to another
 pub(super) fn try_cast_literal_to_type(
     lit_value: &ScalarValue,
@@ -468,6 +529,24 @@ mod tests {
         // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type
         let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64));
         assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+        // cast(c1, UTF8) < '123', only eq/not_eq should be optimized
+        let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123"));
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+        // cast(c1, UTF8) = '0123', cast(cast('0123', Int32), UTF8) != '0123', so '0123' should not
+        // be casted
+        let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("0123"));
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+        // cast(c1, UTF8) = 'not a number', should not be able to cast to column type
+        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("not a number"));
+        assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
+
+        // cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it will
+        // not be optimized to integer comparison
+        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("99999999999"));
+        assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
     }
 
     #[test]
@@ -496,6 +575,21 @@ mod tests {
         let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32));
         let expected = null_bool();
         assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
+
+        // cast(c1, UTF8) = '123' => c1 = 123
+        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("123"));
+        let expected = col("c1").eq(lit(123i32));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
+
+        // cast(c1, UTF8) != '123' => c1 != 123
+        let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123"));
+        let expected = col("c1").not_eq(lit(123i32));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
+
+        // cast(c1, UTF8) = NULL => c1 = NULL
+        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None)));
+        let expected = col("c1").eq(lit(ScalarValue::Int32(None)));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
     }
 
     #[test]
@@ -505,6 +599,16 @@ mod tests {
         let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64));
         let expected = col("c6").eq(lit(0u32));
         assert_eq!(optimize_test(expr_input, &schema), expected);
+
+        // cast(c6, UTF8) = "123" => c6 = 123
+        let expr_input = cast(col("c6"), DataType::Utf8).eq(lit("123"));
+        let expected = col("c6").eq(lit(123u32));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
+
+        // cast(c6, UTF8) != "123" => c6 != 123
+        let expr_input = cast(col("c6"), DataType::Utf8).not_eq(lit("123"));
+        let expected = col("c6").not_eq(lit(123u32));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
     }
 
     #[test]
diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt
index 521aa33409..67965146e7 100644
--- a/datafusion/sqllogictest/test_files/push_down_filter.slt
+++ b/datafusion/sqllogictest/test_files/push_down_filter.slt
@@ -188,6 +188,7 @@ select * from test_filter_with_limit where value = 2 limit 1;
 ----
 2 2
 
+
 # Tear down test_filter_with_limit table:
 statement ok
 DROP TABLE test_filter_with_limit;
@@ -195,3 +196,66 @@ DROP TABLE test_filter_with_limit;
 # Tear down src_table table:
 statement ok
 DROP TABLE src_table;
+
+
+query I
+COPY (VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10))
+TO 'test_files/scratch/push_down_filter/t.parquet'
+STORED AS PARQUET;
+----
+10
+
+statement ok
+CREATE EXTERNAL TABLE t
+(
+  a INT
+)
+STORED AS PARQUET
+LOCATION 'test_files/scratch/push_down_filter/t.parquet';
+
+
+# The predicate should not have a column cast when the value is a valid i32
+query TT
+explain select a from t where a = '100';
+----
+logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)]
+
+# The predicate should not have a column cast when the value is a valid i32
+query TT
+explain select a from t where a != '100';
+----
+logical_plan TableScan: t projection=[a], full_filters=[t.a != Int32(100)]
+
+# The predicate should still have the column cast when the value is NOT a valid i32
+query TT
+explain select a from t where a = '99999999999';
+----
+logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99999999999")]
+
+# The predicate should still have the column cast when the value is NOT a valid i32
+query TT
+explain select a from t where a = '99.99';
+----
+logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99.99")]
+
+# The predicate should still have the column cast when the value is NOT a valid i32
+query TT
+explain select a from t where a = '';
+----
+logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("")]
+
+# The predicate should not have a column cast when the operator is = or != and the literal survives a round-trip cast without losing information.
+query TT
+explain select a from t where cast(a as string) = '100';
+----
+logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)]
+
+# The predicate should still have the column cast when a round-trip cast changes the literal's string representation (the leading zero is lost).
+query TT
+explain select a from t where CAST(a AS string) = '0123';
+----
+logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("0123")]
+
+
+statement ok
+drop table t;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to