This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new eb4ee6272c Move `UnwrapCastInComparison` into `Simplifier` (#15012)
eb4ee6272c is described below
commit eb4ee6272c77a2724c75edc714a93a1dd3e2c13d
Author: Jay Zhan <[email protected]>
AuthorDate: Thu Mar 6 19:35:20 2025 +0800
Move `UnwrapCastInComparison` into `Simplifier` (#15012)
* add unwrap in simplify expr
* rm unwrap cast
* return err
* rename
* fix
* fmt
* add unwrap_cast module to simplify expressions
* tweak comment
* Move tests
* Rewrite to use simplifier schema
* Update tests for simplify logic
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/tests/sql/explain_analyze.rs | 7 +-
datafusion/optimizer/src/lib.rs | 1 -
datafusion/optimizer/src/optimizer.rs | 3 -
.../src/simplify_expressions/expr_simplifier.rs | 92 +++-
.../optimizer/src/simplify_expressions/mod.rs | 1 +
.../unwrap_cast.rs} | 461 +++++++++------------
datafusion/sqllogictest/test_files/explain.slt | 4 -
7 files changed, 291 insertions(+), 278 deletions(-)
diff --git a/datafusion/core/tests/sql/explain_analyze.rs
b/datafusion/core/tests/sql/explain_analyze.rs
index 3bdc71a8eb..e8ef34c2af 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -355,7 +355,8 @@ async fn csv_explain_verbose() {
async fn csv_explain_inlist_verbose() {
let ctx = SessionContext::new();
register_aggregate_csv_by_sql(&ctx).await;
- let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 in
(1,2,4)";
+ // IN lists with 3 or fewer elements are rewritten to an OR chain, so test with 4 elements
+ let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 in
(1,2,4,5)";
let actual = execute(&ctx, sql).await;
// Optimized by PreCastLitInComparisonExpressions rule
@@ -368,12 +369,12 @@ async fn csv_explain_inlist_verbose() {
// before optimization (Int64 literals)
assert_contains!(
&actual,
- "aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4)])"
+ "aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4), Int64(5)])"
);
// after optimization (casted to Int8)
assert_contains!(
&actual,
- "aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4)])"
+ "aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4), Int8(5)])"
);
}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 61ca9b31cd..1280bf2f46 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -60,7 +60,6 @@ pub mod replace_distinct_aggregate;
pub mod scalar_subquery_to_join;
pub mod simplify_expressions;
pub mod single_distinct_to_groupby;
-pub mod unwrap_cast_in_comparison;
pub mod utils;
#[cfg(test)]
diff --git a/datafusion/optimizer/src/optimizer.rs
b/datafusion/optimizer/src/optimizer.rs
index 49bce3c1ce..018ad8ace0 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -54,7 +54,6 @@ use
crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
use crate::scalar_subquery_to_join::ScalarSubqueryToJoin;
use crate::simplify_expressions::SimplifyExpressions;
use crate::single_distinct_to_groupby::SingleDistinctToGroupBy;
-use crate::unwrap_cast_in_comparison::UnwrapCastInComparison;
use crate::utils::log_plan;
/// `OptimizerRule`s transforms one [`LogicalPlan`] into another which
@@ -243,7 +242,6 @@ impl Optimizer {
let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
Arc::new(EliminateNestedUnion::new()),
Arc::new(SimplifyExpressions::new()),
- Arc::new(UnwrapCastInComparison::new()),
Arc::new(ReplaceDistinctWithAggregate::new()),
Arc::new(EliminateJoin::new()),
Arc::new(DecorrelatePredicateSubquery::new()),
@@ -266,7 +264,6 @@ impl Optimizer {
// The previous optimizations added expressions and projections,
// that might benefit from the following rules
Arc::new(SimplifyExpressions::new()),
- Arc::new(UnwrapCastInComparison::new()),
Arc::new(CommonSubexprEliminate::new()),
Arc::new(EliminateGroupByConstant::new()),
Arc::new(OptimizeProjections::new()),
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 840c108905..d5a1b84e6a 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -32,7 +32,6 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
};
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result,
ScalarValue};
-use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator,
Volatility,
WindowFunctionDefinition,
@@ -42,14 +41,23 @@ use datafusion_expr::{
expr::{InList, InSubquery, WindowFunction},
utils::{iter_conjunction, iter_conjunction_owned},
};
+use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast};
use datafusion_physical_expr::{create_physical_expr,
execution_props::ExecutionProps};
use super::inlist_simplifier::ShortenInListSimplifier;
use super::utils::*;
-use crate::analyzer::type_coercion::TypeCoercionRewriter;
use crate::simplify_expressions::guarantees::GuaranteeRewriter;
use crate::simplify_expressions::regex::simplify_regex_expr;
+use crate::simplify_expressions::unwrap_cast::{
+ is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
+ is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
+ unwrap_cast_in_comparison_for_binary,
+};
use crate::simplify_expressions::SimplifyInfo;
+use crate::{
+ analyzer::type_coercion::TypeCoercionRewriter,
+ simplify_expressions::unwrap_cast::try_cast_literal_to_type,
+};
use indexmap::IndexSet;
use regex::Regex;
@@ -1742,6 +1750,86 @@ impl<S: SimplifyInfo> TreeNodeRewriter for
Simplifier<'_, S> {
}
}
+ // =======================================
+ // unwrap_cast_in_comparison
+ // =======================================
+ //
+ // For case:
+ // try_cast/cast(expr as data_type) op literal
+ Expr::BinaryExpr(BinaryExpr { left, op, right })
+ if
is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
+ info, &left, &right,
+ ) && op.supports_propagation() =>
+ {
+ unwrap_cast_in_comparison_for_binary(info, left, right, op)?
+ }
+ // literal op try_cast/cast(expr as data_type)
+ // -->
+ // try_cast/cast(expr as data_type) op_swap literal
+ Expr::BinaryExpr(BinaryExpr { left, op, right })
+ if
is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
+ info, &right, &left,
+ ) && op.supports_propagation()
+ && op.swap().is_some() =>
+ {
+ unwrap_cast_in_comparison_for_binary(
+ info,
+ right,
+ left,
+ op.swap().unwrap(),
+ )?
+ }
+ // For case:
+ // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
+ Expr::InList(InList {
+ expr: mut left,
+ list,
+ negated,
+ }) if
is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
+ info, &left, &list,
+ ) =>
+ {
+ let (Expr::TryCast(TryCast {
+ expr: left_expr, ..
+ })
+ | Expr::Cast(Cast {
+ expr: left_expr, ..
+ })) = left.as_mut()
+ else {
+ return internal_err!("Expect cast expr, but got {:?}",
left)?;
+ };
+
+ let expr_type = info.get_data_type(left_expr)?;
+ let right_exprs = list
+ .into_iter()
+ .map(|right| {
+ match right {
+ Expr::Literal(right_lit_value) => {
+ // if right_lit_value can be cast to the type of left_expr, unwrap
+ // the cast/try_cast on the left and cast the literal instead
+ let Some(value) =
try_cast_literal_to_type(&right_lit_value, &expr_type) else {
+ internal_err!(
+ "Can't cast the list expr {:?} to type
{:?}",
+ right_lit_value, &expr_type
+ )?
+ };
+ Ok(lit(value))
+ }
+ other_expr => internal_err!(
+ "Only support literal expr to optimize, but
the expr is {:?}",
+ &other_expr
+ ),
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ Transformed::yes(Expr::InList(InList {
+ expr: std::mem::take(left_expr),
+ list: right_exprs,
+ negated,
+ }))
+ }
+
// no additional rewrites possible
expr => Transformed::no(expr),
})
diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs
b/datafusion/optimizer/src/simplify_expressions/mod.rs
index 46c066c11c..5fbee02e39 100644
--- a/datafusion/optimizer/src/simplify_expressions/mod.rs
+++ b/datafusion/optimizer/src/simplify_expressions/mod.rs
@@ -23,6 +23,7 @@ mod guarantees;
mod inlist_simplifier;
mod regex;
pub mod simplify_exprs;
+mod unwrap_cast;
mod utils;
// backwards compatibility
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
similarity index 79%
rename from datafusion/optimizer/src/unwrap_cast_in_comparison.rs
rename to datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
index e2b8a966cb..7670bdf98b 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs
@@ -15,274 +15,176 @@
// specific language governing permissions and limitations
// under the License.
-//! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)`
+//! Unwrap casts in binary comparisons and IN lists
+//!
+//! The functions in this module attempt to remove casts from
+//! comparisons to literals ([`ScalarValue`]s) by applying the casts
+//! to the literals if possible. It is inspired by the optimizer rule
+//! `UnwrapCastInBinaryComparison` of Spark.
+//!
+//! Removing casts often improves performance because:
+//! 1. The cast is done once (to the literal) rather than to every value
+//! 2. Can enable other optimizations such as predicate pushdown that
+//! don't support casting
+//!
+//! The rewrite is applied to expressions of the following forms:
+//!
+//! 1. `cast(left_expr as data_type) comparison_op literal_expr`
+//! 2. `literal_expr comparison_op cast(left_expr as data_type)`
+//! 3. `cast(left_expr as data_type) IN (literal_expr1, literal_expr2, ...)`
+//! 4. `literal_expr IN (cast(expr1), cast(expr2), ...)`
+//!
+//! If the expression matches one of the forms above, the simplifier
+//! checks that the value of `literal` is within the (min, max) range
+//! of the inner expr's data type. If it is, the literal is cast to the
+//! data type of the expr on the other side, and the cast is removed
+//! from that side.
+//!
+//! # Example
+//!
+//! If the DataType of c1 is INT32. Given the filter
+//!
+//! ```text
+//! cast(c1 as INT64) > INT64(10)
+//! ```
+//!
+//! The cast is removed and the expression is rewritten to:
+//!
+//! ```text
+//! c1 > INT32(10)
+//! ```
+//!
use std::cmp::Ordering;
-use std::mem;
-use std::sync::Arc;
-use crate::optimizer::ApplyOrder;
-use crate::{OptimizerConfig, OptimizerRule};
-
-use crate::utils::NamePreserver;
use arrow::datatypes::{
DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
MIN_DECIMAL128_FOR_EACH_PRECISION,
};
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
-use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
-use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result,
ScalarValue};
-use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast};
-use datafusion_expr::utils::merge_schema;
-use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan};
-
-/// [`UnwrapCastInComparison`] attempts to remove casts from
-/// comparisons to literals ([`ScalarValue`]s) by applying the casts
-/// to the literals if possible. It is inspired by the optimizer rule
-/// `UnwrapCastInBinaryComparison` of Spark.
-///
-/// Removing casts often improves performance because:
-/// 1. The cast is done once (to the literal) rather than to every value
-/// 2. Can enable other optimizations such as predicate pushdown that
-/// don't support casting
-///
-/// The rule is applied to expressions of the following forms:
-///
-/// 1. `cast(left_expr as data_type) comparison_op literal_expr`
-/// 2. `literal_expr comparison_op cast(left_expr as data_type)`
-/// 3. `cast(literal_expr) IN (expr1, expr2, ...)`
-/// 4. `literal_expr IN (cast(expr1) , cast(expr2), ...)`
-///
-/// If the expression matches one of the forms above, the rule will
-/// ensure the value of `literal` is in range(min, max) of the
-/// expr's data_type, and if the scalar is within range, the literal
-/// will be casted to the data type of expr on the other side, and the
-/// cast will be removed from the other side.
-///
-/// # Example
-///
-/// If the DataType of c1 is INT32. Given the filter
-///
-/// ```text
-/// Filter: cast(c1 as INT64) > INT64(10)`
-/// ```
-///
-/// This rule will remove the cast and rewrite the expression to:
-///
-/// ```text
-/// Filter: c1 > INT32(10)
-/// ```
-///
-#[derive(Default, Debug)]
-pub struct UnwrapCastInComparison {}
-
-impl UnwrapCastInComparison {
- pub fn new() -> Self {
- Self::default()
+use datafusion_common::{internal_err, tree_node::Transformed};
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{lit, BinaryExpr};
+use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast};
+
+pub(super) fn unwrap_cast_in_comparison_for_binary<S: SimplifyInfo>(
+ info: &S,
+ cast_expr: Box<Expr>,
+ literal: Box<Expr>,
+ op: Operator,
+) -> Result<Transformed<Expr>> {
+ match (*cast_expr, *literal) {
+ (
+ Expr::TryCast(TryCast { expr, .. }) | Expr::Cast(Cast { expr, ..
}),
+ Expr::Literal(lit_value),
+ ) => {
+ let Ok(expr_type) = info.get_data_type(&expr) else {
+ return internal_err!("Can't get the data type of the expr
{:?}", &expr);
+ };
+ // if lit_value can be cast to the type of the inner expr, unwrap the
+ // cast/try_cast and cast the literal to that type instead
+ let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type)
else {
+ return internal_err!(
+ "Can't cast the literal expr {:?} to type {:?}",
+ &lit_value,
+ &expr_type
+ );
+ };
+ Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
+ left: expr,
+ op,
+ right: Box::new(lit(value)),
+ })))
+ }
+ _ => internal_err!("Expect cast expr and literal"),
}
}
-impl OptimizerRule for UnwrapCastInComparison {
- fn name(&self) -> &str {
- "unwrap_cast_in_comparison"
- }
-
- fn apply_order(&self) -> Option<ApplyOrder> {
- Some(ApplyOrder::BottomUp)
- }
+pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
+ S: SimplifyInfo,
+>(
+ info: &S,
+ expr: &Expr,
+ literal: &Expr,
+) -> bool {
+ match (expr, literal) {
+ (
+ Expr::TryCast(TryCast {
+ expr: left_expr, ..
+ })
+ | Expr::Cast(Cast {
+ expr: left_expr, ..
+ }),
+ Expr::Literal(lit_val),
+ ) => {
+ let Ok(expr_type) = info.get_data_type(left_expr) else {
+ return false;
+ };
- fn supports_rewrite(&self) -> bool {
- true
- }
+ let Ok(lit_type) = info.get_data_type(literal) else {
+ return false;
+ };
- fn rewrite(
- &self,
- plan: LogicalPlan,
- _config: &dyn OptimizerConfig,
- ) -> Result<Transformed<LogicalPlan>> {
- let mut schema = merge_schema(&plan.inputs());
-
- if let LogicalPlan::TableScan(ts) = &plan {
- let source_schema = DFSchema::try_from_qualified_schema(
- ts.table_name.clone(),
- &ts.source.schema(),
- )?;
- schema.merge(&source_schema);
+ try_cast_literal_to_type(lit_val, &expr_type).is_some()
+ && is_supported_type(&expr_type)
+ && is_supported_type(&lit_type)
}
+ _ => false,
+ }
+}
- schema.merge(plan.schema());
+pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist<
+ S: SimplifyInfo,
+>(
+ info: &S,
+ expr: &Expr,
+ list: &[Expr],
+) -> bool {
+ let (Expr::TryCast(TryCast {
+ expr: left_expr, ..
+ })
+ | Expr::Cast(Cast {
+ expr: left_expr, ..
+ })) = expr
+ else {
+ return false;
+ };
- let mut expr_rewriter = UnwrapCastExprRewriter {
- schema: Arc::new(schema),
- };
+ let Ok(expr_type) = info.get_data_type(left_expr) else {
+ return false;
+ };
- let name_preserver = NamePreserver::new(&plan);
- plan.map_expressions(|expr| {
- let original_name = name_preserver.save(&expr);
- expr.rewrite(&mut expr_rewriter)
- .map(|transformed| transformed.update_data(|e|
original_name.restore(e)))
- })
+ if !is_supported_type(&expr_type) {
+ return false;
}
-}
-struct UnwrapCastExprRewriter {
- schema: DFSchemaRef,
-}
+ for right in list {
+ let Ok(right_type) = info.get_data_type(right) else {
+ return false;
+ };
-impl TreeNodeRewriter for UnwrapCastExprRewriter {
- type Node = Expr;
-
- fn f_up(&mut self, mut expr: Expr) -> Result<Transformed<Expr>> {
- match &mut expr {
- // For case:
- // try_cast/cast(expr as data_type) op literal
- // literal op try_cast/cast(expr as data_type)
- Expr::BinaryExpr(BinaryExpr { left, op, right })
- if {
- let Ok(left_type) = left.get_type(&self.schema) else {
- return Ok(Transformed::no(expr));
- };
- let Ok(right_type) = right.get_type(&self.schema) else {
- return Ok(Transformed::no(expr));
- };
- is_supported_type(&left_type)
- && is_supported_type(&right_type)
- && op.supports_propagation()
- } =>
- {
- match (left.as_mut(), right.as_mut()) {
- (
- Expr::Literal(left_lit_value),
- Expr::TryCast(TryCast {
- expr: right_expr, ..
- })
- | Expr::Cast(Cast {
- expr: right_expr, ..
- }),
- ) => {
- // if the left_lit_value can be cast to the type of
expr
- // we need to unwrap the cast for cast/try_cast expr,
and add cast to the literal
- let Ok(expr_type) = right_expr.get_type(&self.schema)
else {
- return Ok(Transformed::no(expr));
- };
- match expr_type {
- //
https://github.com/apache/datafusion/issues/12180
- DataType::Utf8View => Ok(Transformed::no(expr)),
- _ => {
- let Some(value) =
- try_cast_literal_to_type(left_lit_value,
&expr_type)
- else {
- return Ok(Transformed::no(expr));
- };
- **left = lit(value);
- // unwrap the cast/try_cast for the right expr
- **right = mem::take(right_expr);
- Ok(Transformed::yes(expr))
- }
- }
- }
- (
- Expr::TryCast(TryCast {
- expr: left_expr, ..
- })
- | Expr::Cast(Cast {
- expr: left_expr, ..
- }),
- Expr::Literal(right_lit_value),
- ) => {
- // if the right_lit_value can be cast to the type of
expr
- // we need to unwrap the cast for cast/try_cast expr,
and add cast to the literal
- let Ok(expr_type) = left_expr.get_type(&self.schema)
else {
- return Ok(Transformed::no(expr));
- };
- match expr_type {
- //
https://github.com/apache/datafusion/issues/12180
- DataType::Utf8View => Ok(Transformed::no(expr)),
- _ => {
- let Some(value) =
- try_cast_literal_to_type(right_lit_value,
&expr_type)
- else {
- return Ok(Transformed::no(expr));
- };
- // unwrap the cast/try_cast for the left expr
- **left = mem::take(left_expr);
- **right = lit(value);
- Ok(Transformed::yes(expr))
- }
- }
- }
- _ => Ok(Transformed::no(expr)),
- }
- }
- // For case:
- // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
- Expr::InList(InList {
- expr: left, list, ..
- }) => {
- let (Expr::TryCast(TryCast {
- expr: left_expr, ..
- })
- | Expr::Cast(Cast {
- expr: left_expr, ..
- })) = left.as_mut()
- else {
- return Ok(Transformed::no(expr));
- };
- let Ok(expr_type) = left_expr.get_type(&self.schema) else {
- return Ok(Transformed::no(expr));
- };
- if !is_supported_type(&expr_type) {
- return Ok(Transformed::no(expr));
- }
- let Ok(right_exprs) = list
- .iter()
- .map(|right| {
- let right_type = right.get_type(&self.schema)?;
- if !is_supported_type(&right_type) {
- internal_err!(
- "The type of list expr {} is not supported",
- &right_type
- )?;
- }
- match right {
- Expr::Literal(right_lit_value) => {
- // if the right_lit_value can be casted to the
type of internal_left_expr
- // we need to unwrap the cast for
cast/try_cast expr, and add cast to the literal
- let Some(value) =
try_cast_literal_to_type(right_lit_value, &expr_type) else {
- internal_err!(
- "Can't cast the list expr {:?} to type
{:?}",
- right_lit_value, &expr_type
- )?
- };
- Ok(lit(value))
- }
- other_expr => internal_err!(
- "Only support literal expr to optimize, but
the expr is {:?}",
- &other_expr
- ),
- }
- })
- .collect::<Result<Vec<_>>>() else {
- return Ok(Transformed::no(expr))
- };
- **left = mem::take(left_expr);
- *list = right_exprs;
- Ok(Transformed::yes(expr))
- }
- // TODO: handle other expr type and dfs visit them
- _ => Ok(Transformed::no(expr)),
+ if !is_supported_type(&right_type) {
+ return false;
+ }
+
+ match right {
+ Expr::Literal(lit_val)
+ if try_cast_literal_to_type(lit_val, &expr_type).is_some() =>
{}
+ _ => return false,
}
}
+
+ true
}
-/// Returns true if [UnwrapCastExprRewriter] supports this data type
+/// Returns true if unwrap_cast_in_comparison supports this data type
fn is_supported_type(data_type: &DataType) -> bool {
is_supported_numeric_type(data_type)
|| is_supported_string_type(data_type)
|| is_supported_dictionary_type(data_type)
}
-/// Returns true if [[UnwrapCastExprRewriter]] support this numeric type
+/// Returns true if unwrap_cast_in_comparison supports this numeric type
fn is_supported_numeric_type(data_type: &DataType) -> bool {
matches!(
data_type,
@@ -299,7 +201,7 @@ fn is_supported_numeric_type(data_type: &DataType) -> bool {
)
}
-/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a
string
+/// Returns true if unwrap_cast_in_comparison supports casting this value as a
string
fn is_supported_string_type(data_type: &DataType) -> bool {
matches!(
data_type,
@@ -307,14 +209,14 @@ fn is_supported_string_type(data_type: &DataType) -> bool
{
)
}
-/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a
dictionary
+/// Returns true if unwrap_cast_in_comparison supports casting this value as a
dictionary
fn is_supported_dictionary_type(data_type: &DataType) -> bool {
matches!(data_type,
DataType::Dictionary(_, inner) if is_supported_type(inner))
}
/// Convert a literal value from one data type to another
-fn try_cast_literal_to_type(
+pub(super) fn try_cast_literal_to_type(
lit_value: &ScalarValue,
target_type: &DataType,
) -> Option<ScalarValue> {
@@ -540,13 +442,16 @@ fn cast_between_timestamp(from: &DataType, to: &DataType,
value: i128) -> Option
#[cfg(test)]
mod tests {
- use std::collections::HashMap;
-
use super::*;
+ use std::collections::HashMap;
+ use std::sync::Arc;
+ use crate::simplify_expressions::ExprSimplifier;
use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::Field;
- use datafusion_common::tree_node::TransformedResult;
+ use datafusion_common::{DFSchema, DFSchemaRef};
+ use datafusion_expr::execution_props::ExecutionProps;
+ use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{cast, col, in_list, try_cast};
#[test]
@@ -587,9 +492,9 @@ mod tests {
let expected = col("c1").lt(null_i32());
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
- // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12)
+ // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) =>
BOOL(NULL)
let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32));
- let expected = null_i8().lt(lit(12i8));
+ let expected = null_bool();
assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
}
@@ -623,7 +528,7 @@ mod tests {
// Verify reversed argument order
// arrow_cast('value', 'Dictionary<Int32, Utf8>') = cast(str1 as
Dictionary<Int32, Utf8>) => Utf8('value1') = str1
let expr_input = lit(dict.clone()).eq(cast(col("str1"),
dict.data_type()));
- let expected = lit("value").eq(col("str1"));
+ let expected = col("str1").eq(lit("value"));
assert_eq!(optimize_test(expr_input, &schema), expected);
}
@@ -740,15 +645,27 @@ mod tests {
#[test]
fn test_unwrap_list_cast_comparison() {
let schema = expr_test_schema();
- // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN
(INT32(12),INT32(24))
- let expr_lt =
- cast(col("c1"), DataType::Int64).in_list(vec![lit(12i64),
lit(24i64)], false);
- let expected = col("c1").in_list(vec![lit(12i32), lit(24i32)], false);
+ // INT32(C1) IN (INT64(12),INT64(23),INT64(34),INT64(56),INT64(78)) ->
+ // INT32(C1) IN (INT32(12),INT32(23),INT32(34),INT32(56),INT32(78))
+ let expr_lt = cast(col("c1"), DataType::Int64).in_list(
+ vec![lit(12i64), lit(23i64), lit(34i64), lit(56i64), lit(78i64)],
+ false,
+ );
+ let expected = col("c1").in_list(
+ vec![lit(12i32), lit(23i32), lit(34i32), lit(56i32), lit(78i32)],
+ false,
+ );
assert_eq!(optimize_test(expr_lt, &schema), expected);
- // INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN
(INT32(12),INT32(24))
- let expr_lt =
- cast(col("c2"), DataType::Int32).in_list(vec![null_i32(),
lit(14i32)], false);
- let expected = col("c2").in_list(vec![null_i64(), lit(14i64)], false);
+ // CAST(INT64(C2) AS INT32) IN (INT32(NULL),INT32(24),INT64(34),INT64(56),INT64(78)) ->
+ // INT64(C2) IN (INT64(NULL),INT64(24),INT64(34),INT64(56),INT64(78))
+ let expr_lt = cast(col("c2"), DataType::Int32).in_list(
+ vec![null_i32(), lit(24i32), lit(34i64), lit(56i64), lit(78i64)],
+ false,
+ );
+ let expected = col("c2").in_list(
+ vec![null_i64(), lit(24i64), lit(34i64), lit(56i64), lit(78i64)],
+ false,
+ );
assert_eq!(optimize_test(expr_lt, &schema), expected);
@@ -774,10 +691,14 @@ mod tests {
);
assert_eq!(optimize_test(expr_lt, &schema), expected);
- // cast(INT32(12), INT64) IN (.....)
- let expr_lt = cast(lit(12i32), DataType::Int64)
- .in_list(vec![lit(13i64), lit(12i64)], false);
- let expected = lit(12i32).in_list(vec![lit(13i32), lit(12i32)], false);
+ // cast(INT32(12), INT64) IN (.....) =>
+ // INT64(12) IN (INT64(12),INT64(13),INT64(14),INT64(15),INT64(16))
+ // => true
+ let expr_lt = cast(lit(12i32), DataType::Int64).in_list(
+ vec![lit(12i64), lit(13i64), lit(14i64), lit(15i64), lit(16i64)],
+ false,
+ );
+ let expected = lit(true);
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
@@ -815,8 +736,12 @@ mod tests {
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
// inlist for unsupported data type
- let expr_input =
- in_list(cast(col("c6"), DataType::Float64), vec![lit(0f64)],
false);
+ let expr_input = in_list(
+ cast(col("c6"), DataType::Float64),
+ // need more literals to avoid rewriting to binary expr
+ vec![lit(0f64), lit(1f64), lit(2f64), lit(3f64), lit(4f64)],
+ false,
+ );
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
}
@@ -833,10 +758,12 @@ mod tests {
}
fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
- let mut expr_rewriter = UnwrapCastExprRewriter {
- schema: Arc::clone(schema),
- };
- expr.rewrite(&mut expr_rewriter).data().unwrap()
+ let props = ExecutionProps::new();
+ let simplifier = ExprSimplifier::new(
+ SimplifyContext::new(&props).with_schema(Arc::clone(schema)),
+ );
+
+ simplifier.simplify(expr).unwrap()
}
fn expr_test_schema() -> DFSchemaRef {
@@ -862,6 +789,10 @@ mod tests {
)
}
+ fn null_bool() -> Expr {
+ lit(ScalarValue::Boolean(None))
+ }
+
fn null_i8() -> Expr {
lit(ScalarValue::Int8(None))
}
diff --git a/datafusion/sqllogictest/test_files/explain.slt
b/datafusion/sqllogictest/test_files/explain.slt
index d32ddd1512..cab7308f6f 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -181,7 +181,6 @@ logical_plan after type_coercion SAME TEXT AS ABOVE
analyzed_logical_plan SAME TEXT AS ABOVE
logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
logical_plan after simplify_expressions SAME TEXT AS ABOVE
-logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
logical_plan after eliminate_join SAME TEXT AS ABOVE
logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE
@@ -200,13 +199,11 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE
logical_plan after push_down_filter SAME TEXT AS ABOVE
logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE
logical_plan after simplify_expressions SAME TEXT AS ABOVE
-logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE
logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE
logical_plan after optimize_projections TableScan: simple_explain_test
projection=[a, b, c]
logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
logical_plan after simplify_expressions SAME TEXT AS ABOVE
-logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
logical_plan after eliminate_join SAME TEXT AS ABOVE
logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE
@@ -225,7 +222,6 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE
logical_plan after push_down_filter SAME TEXT AS ABOVE
logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE
logical_plan after simplify_expressions SAME TEXT AS ABOVE
-logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE
logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE
logical_plan after optimize_projections SAME TEXT AS ABOVE
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]