This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 26c0c49ad1 perf: unwrap cast for comparing ints =/!= strings (#15110)
26c0c49ad1 is described below
commit 26c0c49ad1dd91e6cbb10b16d81faa3a6f361d81
Author: Li-Lun Lin <[email protected]>
AuthorDate: Fri Mar 28 03:41:33 2025 +0800
perf: unwrap cast for comparing ints =/!= strings (#15110)
* perf: unwrap cast for comparing ints =/!= strings
* fix: update casting logic
* test: add more unit test and new sqllogictest
* Tweak slt tests
* Revert "perf: unwrap cast for comparing ints =/!= strings"
This reverts commit 808d6ab3ceb0281d055965a330b8ffb1c47fa65b.
* fix: eliminate column cast and cast literal before coercion
* fix: physical expr coercion test
* feat: unwrap cast after round-trip cast verification
* fix: unwrap cast on round-trip cast stable strings
* revert: remove avoid cast changes
* refactor: apply review suggestions
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/optimizer/src/analyzer/type_coercion.rs | 1 +
.../src/simplify_expressions/expr_simplifier.rs | 4 +-
.../src/simplify_expressions/unwrap_cast.rs | 104 +++++++++++++++++++++
.../sqllogictest/test_files/push_down_filter.slt | 64 +++++++++++++
4 files changed, 171 insertions(+), 2 deletions(-)
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 07eb795462..a77249424f 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -296,6 +296,7 @@ impl<'a> TypeCoercionRewriter<'a> {
&right.get_type(right_schema)?,
)
.get_input_types()?;
+
Ok((
left.cast_to(&left_type, left_schema)?,
right.cast_to(&right_type, right_schema)?,
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index ce10c7e5c6..9003467703 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -1758,7 +1758,7 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_,
S> {
// try_cast/cast(expr as data_type) op literal
Expr::BinaryExpr(BinaryExpr { left, op, right })
if
is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
- info, &left, &right,
+ info, &left, op, &right,
) && op.supports_propagation() =>
{
unwrap_cast_in_comparison_for_binary(info, left, right, op)?
@@ -1768,7 +1768,7 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_,
S> {
// try_cast/cast(expr as data_type) op_swap literal
Expr::BinaryExpr(BinaryExpr { left, op, right })
if
is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
- info, &right, &left,
+ info, &right, op, &left,
) && op.supports_propagation()
&& op.swap().is_some() =>
{
diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
index 7670bdf98b..be71a8cd19 100644
--- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
+++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
@@ -81,6 +81,16 @@ pub(super) fn unwrap_cast_in_comparison_for_binary<S:
SimplifyInfo>(
let Ok(expr_type) = info.get_data_type(&expr) else {
return internal_err!("Can't get the data type of the expr
{:?}", &expr);
};
+
+ if let Some(value) = cast_literal_to_type_with_op(&lit_value,
&expr_type, op)
+ {
+ return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
+ left: expr,
+ op,
+ right: Box::new(lit(value)),
+ })));
+ };
+
// if the lit_value can be casted to the type of internal_left_expr
// we need to unwrap the cast for cast/try_cast expr, and add cast
to the literal
let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type)
else {
@@ -105,6 +115,7 @@ pub(super) fn
is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
>(
info: &S,
expr: &Expr,
+ op: Operator,
literal: &Expr,
) -> bool {
match (expr, literal) {
@@ -125,6 +136,10 @@ pub(super) fn
is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
return false;
};
+ if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some()
{
+ return true;
+ }
+
try_cast_literal_to_type(lit_val, &expr_type).is_some()
&& is_supported_type(&expr_type)
&& is_supported_type(&lit_type)
@@ -215,6 +230,52 @@ fn is_supported_dictionary_type(data_type: &DataType) ->
bool {
DataType::Dictionary(_, inner) if is_supported_type(inner))
}
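+/// Tries to move a cast from an expression (such as a column) to the literal on
+/// the other side of an `=` / `!=` comparison operator.
+///
+/// Specifically, rewrites
+/// ```sql
+/// cast(col) <op> <literal>
+/// ```
+///
+/// to
+///
+/// ```sql
+/// col <op> cast(<literal>)
+/// ```
+///
+/// where `cast(<literal>)` is evaluated here, so the resulting predicate is
+/// `col <op> <casted_literal>`.
+///
+/// Returns `None` (no rewrite) unless `op` is `=` or `!=`, the literal is a
+/// non-null string (`Utf8`, `Utf8View` or `LargeUtf8`), `target_type` is one of
+/// the signed/unsigned integer types, and casting the literal to `target_type`
+/// and back reproduces the original string exactly (so e.g. `'0123'` is not
+/// rewritten).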
+fn cast_literal_to_type_with_op(
+    lit_value: &ScalarValue,
+    target_type: &DataType,
+    op: Operator,
+) -> Option<ScalarValue> {
+    // Only `=` / `!=` against a non-null string literal is handled here; every
+    // other combination returns `None`, and both callers then fall back to the
+    // generic `try_cast_literal_to_type` unwrap-cast path.
+    match (op, lit_value) {
+        (
+            Operator::Eq | Operator::NotEq,
+            ScalarValue::Utf8(Some(_))
+            | ScalarValue::Utf8View(Some(_))
+            | ScalarValue::LargeUtf8(Some(_)),
+        ) => {
+            // Only try for integer types (TODO can we do this for other types
+            // like timestamps)?
+            use DataType::*;
+            if matches!(
+                target_type,
+                Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
+            ) {
+                // If either cast fails, `.ok()?` returns `None` (no rewrite).
+                let casted = lit_value.cast_to(target_type).ok()?;
+                let round_tripped = casted.cast_to(&lit_value.data_type()).ok()?;
+                // Casting back must reproduce the exact original string. This
+                // rejects non-canonical spellings such as '0123': its round trip
+                // is '123', so `col = 123` would not be equivalent to
+                // `CAST(col AS Utf8) = '0123'`.
+                if lit_value != &round_tripped {
+                    return None;
+                }
+                Some(casted)
+            } else {
+                None
+            }
+        }
+        _ => None,
+    }
+}
+
/// Convert a literal value from one data type to another
pub(super) fn try_cast_literal_to_type(
lit_value: &ScalarValue,
@@ -468,6 +529,24 @@ mod tests {
// the 99999999999 is not within the range of MAX(int32) and
MIN(int32), we don't cast the lit(99999999999) to int32 type
let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64));
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+ // cast(c1, UTF8) < '123', only eq/not_eq should be optimized
+ let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123"));
+ assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
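+ // cast(c1, UTF8) = '0123', cast(cast('0123', Int32), UTF8) != '0123', so '0123' should not
+ // be cast.
+ // NOTE(review): the assertion below builds the expression with `.lt`, not `.eq`, so it
+ // repeats the `<` case above instead of exercising the round-trip check; `.eq` looks intended.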
+ let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("0123"));
+ assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+ // cast(c1, UTF8) = 'not a number', should not be able to cast to the column type
+ let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("not a
number"));
+ assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
+
+ // cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it
+ // will not be optimized to an integer comparison
+ let expr_input = cast(col("c1"),
DataType::Utf8).eq(lit("99999999999"));
+ assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
}
#[test]
@@ -496,6 +575,21 @@ mod tests {
let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32));
let expected = null_bool();
assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
+
+ // cast(c1, UTF8) = '123' => c1 = 123
+ let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("123"));
+ let expected = col("c1").eq(lit(123i32));
+ assert_eq!(optimize_test(expr_input, &schema), expected);
+
+ // cast(c1, UTF8) != '123' => c1 != 123
+ let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123"));
+ let expected = col("c1").not_eq(lit(123i32));
+ assert_eq!(optimize_test(expr_input, &schema), expected);
+
+ // cast(c1, UTF8) = NULL => c1 = NULL
+ let expr_input = cast(col("c1"),
DataType::Utf8).eq(lit(ScalarValue::Utf8(None)));
+ let expected = col("c1").eq(lit(ScalarValue::Int32(None)));
+ assert_eq!(optimize_test(expr_input, &schema), expected);
}
#[test]
@@ -505,6 +599,16 @@ mod tests {
let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64));
let expected = col("c6").eq(lit(0u32));
assert_eq!(optimize_test(expr_input, &schema), expected);
+
+ // cast(c6, UTF8) = "123" => c6 = 123
+ let expr_input = cast(col("c6"), DataType::Utf8).eq(lit("123"));
+ let expected = col("c6").eq(lit(123u32));
+ assert_eq!(optimize_test(expr_input, &schema), expected);
+
+ // cast(c6, UTF8) != "123" => c6 != 123
+ let expr_input = cast(col("c6"), DataType::Utf8).not_eq(lit("123"));
+ let expected = col("c6").not_eq(lit(123u32));
+ assert_eq!(optimize_test(expr_input, &schema), expected);
}
#[test]
diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt
b/datafusion/sqllogictest/test_files/push_down_filter.slt
index 521aa33409..67965146e7 100644
--- a/datafusion/sqllogictest/test_files/push_down_filter.slt
+++ b/datafusion/sqllogictest/test_files/push_down_filter.slt
@@ -188,6 +188,7 @@ select * from test_filter_with_limit where value = 2 limit
1;
----
2 2
+
# Tear down test_filter_with_limit table:
statement ok
DROP TABLE test_filter_with_limit;
@@ -195,3 +196,66 @@ DROP TABLE test_filter_with_limit;
# Tear down src_table table:
statement ok
DROP TABLE src_table;
+
+
+# Write a single-column, 10-row parquet file for the unwrap-cast filter tests below
+query I
+COPY (VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10))
+TO 'test_files/scratch/push_down_filter/t.parquet'
+STORED AS PARQUET;
+----
+10
+
+statement ok
+CREATE EXTERNAL TABLE t
+(
+ a INT
+)
+STORED AS PARQUET
+LOCATION 'test_files/scratch/push_down_filter/t.parquet';
+
+
+# The predicate should not have a column cast when the value is a valid i32
+query TT
+explain select a from t where a = '100';
+----
+logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)]
+
+# Same as above for `!=`: no column cast when the value is a valid i32
+query TT
+explain select a from t where a != '100';
+----
+logical_plan TableScan: t projection=[a], full_filters=[t.a != Int32(100)]
+
+# The predicate should still have the column cast when the value is NOT a valid i32 (out of range)
+query TT
+explain select a from t where a = '99999999999';
+----
+logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99999999999")]
+
+# The predicate should still have the column cast when the value is NOT a valid i32 (not an integer)
+query TT
+explain select a from t where a = '99.99';
+----
+logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99.99")]
+
+# The predicate should still have the column cast when the value is NOT a valid i32 (empty string)
+query TT
+explain select a from t where a = '';
+----
+logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("")]
+
+# The predicate should not have a column cast when the operator is = or != and the literal
+# can be round-trip cast without losing information (here the CAST is written explicitly).
+query TT
+explain select a from t where cast(a as string) = '100';
+----
+logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)]
+
+# The predicate should still have the column cast when the literal alters its string
+# representation after round-trip casting (leading zero lost).
+query TT
+explain select a from t where CAST(a AS string) = '0123';
+----
+logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("0123")]
+
+
+# Tear down table t
+statement ok
+drop table t;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]