This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new a1b2112c2 change pre_cast_lit_in_comparison to
unwrap_cast_in_comparison (#3662)
a1b2112c2 is described below
commit a1b2112c219d4ec67a153de41ea894c68736e9f6
Author: Kun Liu <[email protected]>
AuthorDate: Fri Sep 30 22:35:50 2022 +0800
change pre_cast_lit_in_comparison to unwrap_cast_in_comparison (#3662)
* change pre_cast_lit_in_comparison to unwrap_cast_in_comparison
* update some test cases
---
datafusion/core/src/execution/context.rs | 4 +-
datafusion/core/tests/sql/explain_analyze.rs | 20 +-
datafusion/optimizer/src/lib.rs | 2 +-
..._comparison.rs => unwrap_cast_in_comparison.rs} | 340 ++++++++++++---------
datafusion/optimizer/tests/integration-test.rs | 4 +-
5 files changed, 208 insertions(+), 162 deletions(-)
diff --git a/datafusion/core/src/execution/context.rs
b/datafusion/core/src/execution/context.rs
index ff0ccf835..2a805a5fc 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -110,10 +110,10 @@ use datafusion_expr::{TableSource, TableType};
use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
-use
datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
use
datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
use datafusion_optimizer::type_coercion::TypeCoercion;
+use datafusion_optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison;
use datafusion_sql::{
parser::DFParser,
planner::{ContextProvider, SqlToRel},
@@ -1466,9 +1466,9 @@ impl SessionState {
}
let mut rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
- Arc::new(PreCastLitInComparisonExpressions::new()),
Arc::new(TypeCoercion::new()),
Arc::new(SimplifyExpressions::new()),
+ Arc::new(UnwrapCastInComparison::new()),
Arc::new(DecorrelateWhereExists::new()),
Arc::new(DecorrelateWhereIn::new()),
Arc::new(ScalarSubqueryToJoin::new()),
diff --git a/datafusion/core/tests/sql/explain_analyze.rs
b/datafusion/core/tests/sql/explain_analyze.rs
index fe51aedc8..7d09d9483 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -767,8 +767,6 @@ async fn test_physical_plan_display_indent_multi_children()
{
#[tokio::test]
#[cfg_attr(tarpaulin, ignore)]
async fn csv_explain() {
- // TODO: https://github.com/apache/arrow-datafusion/issues/3622 refactor
the `PreCastLitInComparisonExpressions`
-
// This test uses the execute function that create full plan cycle:
logical, optimized logical, and physical,
// then execute the physical plan and return the final explain results
let ctx = SessionContext::new();
@@ -779,23 +777,6 @@ async fn csv_explain() {
// Note can't use `assert_batches_eq` as the plan needs to be
// normalized for filenames and number of cores
- let expected = vec![
- vec![
- "logical_plan",
- "Projection: #aggregate_test_100.c1\
- \n Filter: CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)\
- \n TableScan: aggregate_test_100 projection=[c1, c2],
partial_filters=[CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)]"
- ],
- vec!["physical_plan",
- "ProjectionExec: expr=[c1@0 as c1]\
- \n CoalesceBatchesExec: target_batch_size=4096\
- \n FilterExec: CAST(c2@1 AS Int32) > 10\
- \n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\
- \n CsvExec:
files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true,
limit=None, projection=[c1, c2]\
- \n"
- ]];
- assert_eq!(expected, actual);
-
let expected = vec![
vec![
"logical_plan",
@@ -811,6 +792,7 @@ async fn csv_explain() {
\n CsvExec:
files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true,
limit=None, projection=[c1, c2]\
\n"
]];
+ assert_eq!(expected, actual);
let sql = "explain SELECT c1 FROM aggregate_test_100 where c2 > 10";
let actual = execute(&ctx, sql).await;
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index bfb563436..879658c40 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -35,9 +35,9 @@ pub mod subquery_filter_to_join;
pub mod type_coercion;
pub mod utils;
-pub mod pre_cast_lit_in_comparison;
pub mod rewrite_disjunctive_predicate;
#[cfg(test)]
pub mod test;
+pub mod unwrap_cast_in_comparison;
pub use optimizer::{OptimizerConfig, OptimizerRule};
diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
similarity index 60%
rename from datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
rename to datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index a6d915cf0..0d5665f29 100644
--- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -15,8 +15,9 @@
// specific language governing permissions and limitations
// under the License.
-//! Pre-cast literal binary comparison rule can be only used to the binary
comparison expr.
-//! It can reduce adding the `Expr::Cast` to the expr instead of adding the
`Expr::Cast` to literal expr.
+//! Unwrap-cast comparison rule: currently applies to binary and IN-list comparison exprs,
+//! and other kinds of expr can be added if needed. Rather than keeping an `Expr::Cast` on
+//! the non-literal expr, it removes that cast and casts the literal expr to the inner type.
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::{
DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
@@ -28,14 +29,14 @@ use datafusion_expr::{
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
};
-/// The rule can be only used to the numeric binary comparison with literal
expr, like below pattern:
-/// `left_expr comparison_op literal_expr` or `literal_expr comparison_op
right_expr`.
-/// The data type of two sides must be signed numeric type now, and will
support more data type later.
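+/// The rule applies to numeric binary comparisons of a cast expr with a literal expr, like:
+/// `cast(left_expr as data_type) comparison_op literal_expr` or `literal_expr comparison_op cast(right_expr as data_type)`.
+/// Both sides must have the same data type, which currently must be a signed numeric type; more data types may be supported later.
+/// NOTE(review): the rewrite does not check that the removed cast is widening. For a narrowing
+/// cast (e.g. `cast(int64_col as Int32)`, or a decimal cast that lowers the scale), dropping the
+/// cast may change results for values the cast would overflow or truncate — confirm.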
///
/// If the binary comparison expr match above rules, the optimizer will check
if the value of `literal`
/// is in within range(min,max) which is the range(min,max) of the data type
for `left_expr` or `right_expr`.
///
-/// If this true, the literal expr will be casted to the data type of expr on
the other side, and the result of
+/// If this is true, the literal expr will be cast to the data type of the expr on the other side, and the result of
/// binary comparison will be `left_expr comparison_op cast(literal_expr,
left_data_type)` or
/// `cast(literal_expr, right_data_type) comparison_op right_expr`. For better
optimization,
/// the expr of `cast(literal_expr, target_type)` will be precomputed and
converted to the new expr `new_literal_expr`
@@ -45,19 +46,19 @@ use datafusion_expr::{
/// This is inspired by the optimizer rule `UnwrapCastInBinaryComparison` of
Spark.
/// # Example
///
-/// `Filter: c1 > INT64(10)` will be optimized to `Filter: c1 > CAST(INT64(10)
AS INT32),
+/// `Filter: cast(c1 as INT64) > INT64(10)` will be optimized to `Filter: c1 > CAST(INT64(10) AS INT32)`,
/// and continue to be converted to `Filter: c1 > INT32(10)`, if the DataType
of c1 is INT32.
///
#[derive(Default)]
-pub struct PreCastLitInComparisonExpressions {}
+pub struct UnwrapCastInComparison {}
-impl PreCastLitInComparisonExpressions {
+impl UnwrapCastInComparison {
pub fn new() -> Self {
Self::default()
}
}
-impl OptimizerRule for PreCastLitInComparisonExpressions {
+impl OptimizerRule for UnwrapCastInComparison {
fn optimize(
&self,
plan: &LogicalPlan,
@@ -67,7 +68,7 @@ impl OptimizerRule for PreCastLitInComparisonExpressions {
}
fn name(&self) -> &str {
- "pre_cast_lit_in_comparison"
+ "unwrap_cast_in_comparison"
}
}
@@ -80,7 +81,7 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
let schema = plan.schema();
- let mut expr_rewriter = PreCastLitExprRewriter {
+ let mut expr_rewriter = UnwrapCastExprRewriter {
schema: schema.clone(),
};
@@ -93,17 +94,20 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
}
-struct PreCastLitExprRewriter {
+struct UnwrapCastExprRewriter {
schema: DFSchemaRef,
}
-impl ExprRewriter for PreCastLitExprRewriter {
+impl ExprRewriter for UnwrapCastExprRewriter {
fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
Ok(RewriteRecursion::Continue)
}
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
match &expr {
+ // For case:
+ // try_cast/cast(expr as data_type) op literal
+ // literal op try_cast/cast(expr as data_type)
Expr::BinaryExpr { left, op, right } => {
let left = left.as_ref().clone();
let right = right.as_ref().clone();
@@ -113,29 +117,48 @@ impl ExprRewriter for PreCastLitExprRewriter {
if left_type.is_err() || right_type.is_err() {
return Ok(expr.clone());
}
+ // Because the plan has been done the type coercion, the left
and right must be equal
let left_type = left_type?;
let right_type = right_type?;
- if !left_type.eq(&right_type)
- && is_support_data_type(&left_type)
+ if is_support_data_type(&left_type)
&& is_support_data_type(&right_type)
&& is_comparison_op(op)
{
match (&left, &right) {
- (Expr::Literal(_), Expr::Literal(_)) => {
- // do nothing
- }
- (Expr::Literal(left_lit_value), _) => {
+ (
+ Expr::Literal(left_lit_value),
+ Expr::TryCast { expr, .. } | Expr::Cast { expr, ..
},
+ ) => {
+ // if the left_lit_value can be casted to the type
of expr
+ // we need to unwrap the cast for cast/try_cast
expr, and add cast to the literal
+ let expr_type = expr.get_type(&self.schema)?;
let casted_scalar_value =
- try_cast_literal_to_type(left_lit_value,
&right_type)?;
+ try_cast_literal_to_type(left_lit_value,
&expr_type)?;
if let Some(value) = casted_scalar_value {
- return Ok(binary_expr(lit(value), *op, right));
+ // unwrap the cast/try_cast for the right expr
+ return Ok(binary_expr(
+ lit(value),
+ *op,
+ expr.as_ref().clone(),
+ ));
}
}
- (_, Expr::Literal(right_lit_value)) => {
+ (
+ Expr::TryCast { expr, .. } | Expr::Cast { expr, ..
},
+ Expr::Literal(right_lit_value),
+ ) => {
+ // if the right_lit_value can be casted to the
type of expr
+ // we need to unwrap the cast for cast/try_cast
expr, and add cast to the literal
+ let expr_type = expr.get_type(&self.schema)?;
let casted_scalar_value =
- try_cast_literal_to_type(right_lit_value,
&left_type)?;
+ try_cast_literal_to_type(right_lit_value,
&expr_type)?;
if let Some(value) = casted_scalar_value {
- return Ok(binary_expr(left, *op, lit(value)));
+ // unwrap the cast/try_cast for the left expr
+ return Ok(binary_expr(
+ expr.as_ref().clone(),
+ *op,
+ lit(value),
+ ));
}
}
(_, _) => {
@@ -146,55 +169,75 @@ impl ExprRewriter for PreCastLitExprRewriter {
// return the new binary op
Ok(binary_expr(left, *op, right))
}
+ // For case:
+ // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
Expr::InList {
expr: left_expr,
list,
negated,
} => {
- let left = left_expr.as_ref().clone();
- let left_type = left.get_type(&self.schema);
- if left_type.is_err() {
- // error data type
- return Ok(expr);
- }
- let left_type = left_type?;
- if !is_support_data_type(&left_type) {
- // not supported data type
- return Ok(expr);
- }
- let right_exprs = list
- .iter()
- .map(|right| {
- let right_type = right.get_type(&self.schema)?;
- if !is_support_data_type(&right_type) {
- return Err(DataFusionError::Internal(format!(
- "The type of list expr {} not support",
- &right_type
- )));
- }
- match right {
- Expr::Literal(right_lit_value) => {
- let casted_scalar_value =
- try_cast_literal_to_type(right_lit_value,
&left_type)?;
- if let Some(value) = casted_scalar_value {
- Ok(lit(value))
- } else {
- Err(DataFusionError::Internal(format!(
- "Can't cast the list expr {:?} to type
{:?}",
- right_lit_value, &left_type
- )))
+ if let Some(
+ Expr::TryCast {
+ expr: internal_left_expr,
+ ..
+ }
+ | Expr::Cast {
+ expr: internal_left_expr,
+ ..
+ },
+ ) = Some(left_expr.as_ref())
+ {
+ let internal_left = internal_left_expr.as_ref().clone();
+ let internal_left_type =
internal_left.get_type(&self.schema);
+ if internal_left_type.is_err() {
+ // error data type
+ return Ok(expr);
+ }
+ let internal_left_type = internal_left_type?;
+ if !is_support_data_type(&internal_left_type) {
+ // not supported data type
+ return Ok(expr);
+ }
+ let right_exprs = list
+ .iter()
+ .map(|right| {
+ let right_type = right.get_type(&self.schema)?;
+ if !is_support_data_type(&right_type) {
+ return Err(DataFusionError::Internal(format!(
+ "The type of list expr {} not support",
+ &right_type
+ )));
+ }
+ match right {
+ Expr::Literal(right_lit_value) => {
+ // if the right_lit_value can be casted to
the type of internal_left_expr
+ // we need to unwrap the cast for
cast/try_cast expr, and add cast to the literal
+ let casted_scalar_value =
+
try_cast_literal_to_type(right_lit_value, &internal_left_type)?;
+ if let Some(value) = casted_scalar_value {
+ Ok(lit(value))
+ } else {
+ Err(DataFusionError::Internal(format!(
+ "Can't cast the list expr {:?} to
type {:?}",
+ right_lit_value,
&internal_left_type
+ )))
+ }
}
+ other_expr =>
Err(DataFusionError::Internal(format!(
+ "Only support literal expr to optimize,
but the expr is {:?}",
+ &other_expr
+ ))),
}
- other_expr =>
Err(DataFusionError::Internal(format!(
- "Only support literal expr to optimize, but
the expr is {:?}",
- &other_expr
- ))),
+ })
+ .collect::<Result<Vec<_>>>();
+ match right_exprs {
+ Ok(right_exprs) => {
+ Ok(in_list(internal_left, right_exprs, *negated))
}
- })
- .collect::<Result<Vec<_>>>();
- match right_exprs {
- Ok(right_exprs) => Ok(in_list(left, right_exprs,
*negated)),
- Err(_) => Ok(expr),
+ Err(_) => Ok(expr),
+ }
+ } else {
+ Ok(expr)
}
}
// TODO: handle other expr type and dfs visit them
@@ -326,23 +369,19 @@ fn try_cast_literal_to_type(
#[cfg(test)]
mod tests {
- use crate::pre_cast_lit_in_comparison::PreCastLitExprRewriter;
+ use crate::unwrap_cast_in_comparison::UnwrapCastExprRewriter;
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue};
use datafusion_expr::expr_rewriter::ExprRewritable;
- use datafusion_expr::{col, lit, Expr};
+ use datafusion_expr::{cast, col, lit, try_cast, Expr};
use std::collections::HashMap;
use std::sync::Arc;
#[test]
- fn test_not_cast_lit_comparison() {
+ fn test_not_unwrap_cast_comparison() {
let schema = expr_test_schema();
- // INT8(NULL) < INT32(12)
- let lit_lt_lit =
- lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int32(Some(12))));
- assert_eq!(optimize_test(lit_lt_lit.clone(), &schema), lit_lt_lit);
- // INT32(c1) > INT64(c2)
- let c1_gt_c2 = col("c1").gt(col("c2"));
+ // cast(INT32(c1), INT64) > INT64(c2)
+ let c1_gt_c2 = cast(col("c1"), DataType::Int64).gt(col("c2"));
assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2);
// INT32(c1) < INT32(16), the type is same
@@ -350,110 +389,132 @@ mod tests {
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
// the 99999999999 is not within the range of MAX(int32) and
MIN(int32), we don't cast the lit(99999999999) to int32 type
- let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(99999999999))));
+ let expr_lt = cast(col("c1"), DataType::Int64)
+ .lt(lit(ScalarValue::Int64(Some(99999999999))));
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
}
#[test]
- fn test_pre_cast_lit_comparison() {
+ fn test_unwrap_cast_comparison() {
let schema = expr_test_schema();
- // c1 < INT64(16) -> c1 < cast(INT32(16))
+ // cast(c1, INT64) < INT64(16) -> INT32(c1) < cast(INT32(16))
// the 16 is within the range of MAX(int32) and MIN(int32), we can
cast the 16 to int32(16)
- let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16))));
+ let expr_lt =
+ cast(col("c1"),
DataType::Int64).lt(lit(ScalarValue::Int64(Some(16))));
+ let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
+ assert_eq!(optimize_test(expr_lt, &schema), expected);
+ let expr_lt =
+ try_cast(col("c1"),
DataType::Int64).lt(lit(ScalarValue::Int64(Some(16))));
let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
assert_eq!(optimize_test(expr_lt, &schema), expected);
- // INT64(c2) = INT32(16) => INT64(c2) = INT64(16)
- let c2_eq_lit = col("c2").eq(lit(ScalarValue::Int32(Some(16))));
+ // cast(c2, INT32) = INT32(16) => INT64(c2) = INT64(16)
+ let c2_eq_lit =
+ cast(col("c2"),
DataType::Int32).eq(lit(ScalarValue::Int32(Some(16))));
let expected = col("c2").eq(lit(ScalarValue::Int64(Some(16))));
assert_eq!(optimize_test(c2_eq_lit, &schema), expected);
- // INT32(c1) < INT64(NULL) => INT32(c1) < INT32(NULL)
- let c1_lt_lit_null = col("c1").lt(lit(ScalarValue::Int64(None)));
+ // cast(c1, INT64) < INT64(NULL) => INT32(c1) < INT32(NULL)
+ let c1_lt_lit_null =
+ cast(col("c1"), DataType::Int64).lt(lit(ScalarValue::Int64(None)));
let expected = col("c1").lt(lit(ScalarValue::Int32(None)));
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
+
+ // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12)
+ let lit_lt_lit = cast(lit(ScalarValue::Int8(None)), DataType::Int32)
+ .lt(lit(ScalarValue::Int32(Some(12))));
+ let expected =
lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int8(Some(12))));
+ assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
}
#[test]
- fn test_not_cast_with_decimal_lit_comparison() {
+ fn test_not_unwrap_cast_with_decimal_comparison() {
let schema = expr_test_schema();
// integer to decimal: value is out of the bounds of the decimal
- // c3 = INT64(100000000000000000)
- let expr_eq =
col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000))));
- let expected =
col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000))));
- assert_eq!(optimize_test(expr_eq, &schema), expected);
- // c4 = INT64(1000) will overflow the i128
- let expr_eq = col("c4").eq(lit(ScalarValue::Int64(Some(1000))));
- let expected = col("c4").eq(lit(ScalarValue::Int64(Some(1000))));
- assert_eq!(optimize_test(expr_eq, &schema), expected);
+ // cast(c3, INT64) = INT64(100000000000000000)
+ let expr_eq = cast(col("c3"), DataType::Int64)
+ .eq(lit(ScalarValue::Int64(Some(100000000000000000))));
+ assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
+
+ // cast(c4, INT64) = INT64(1000) will overflow the i128
+ let expr_eq =
+ cast(col("c4"),
DataType::Int64).eq(lit(ScalarValue::Int64(Some(1000))));
+ assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
// decimal to decimal: value will lose the scale when convert to the
target data type
// c3 = DECIMAL(12340,20,4)
- let expr_eq = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340),
20, 4)));
- let expected = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340),
20, 4)));
- assert_eq!(optimize_test(expr_eq, &schema), expected);
+ let expr_eq = cast(col("c3"), DataType::Decimal128(20, 4))
+ .eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4)));
+ assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
// decimal to integer
// c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to
the target data type
- let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10,
1)));
- let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10,
1)));
- assert_eq!(optimize_test(expr_eq, &schema), expected);
+ let expr_eq = cast(col("c1"), DataType::Decimal128(10, 1))
+ .eq(lit(ScalarValue::Decimal128(Some(123), 10, 1)));
+ assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
+
// c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert
to the target data type
- let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10,
2)));
- let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230),
10, 2)));
- assert_eq!(optimize_test(expr_eq, &schema), expected);
+ let expr_eq = cast(col("c1"), DataType::Decimal128(10, 2))
+ .eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2)));
+ assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
}
#[test]
- fn test_pre_cast_with_decimal_lit_comparison() {
+ fn test_unwrap_cast_with_decimal_lit_comparison() {
let schema = expr_test_schema();
// integer to decimal
// c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2));
- let expr_lt = col("c3").lt(lit(ScalarValue::Int64(Some(16))));
+ let expr_lt =
+ try_cast(col("c3"),
DataType::Int64).lt(lit(ScalarValue::Int64(Some(16))));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(1600),
18, 2)));
assert_eq!(optimize_test(expr_lt, &schema), expected);
// c3 < INT64(NULL)
- let c1_lt_lit_null = col("c3").lt(lit(ScalarValue::Int64(None)));
+ let c1_lt_lit_null =
+ cast(col("c3"), DataType::Int64).lt(lit(ScalarValue::Int64(None)));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(None, 18, 2)));
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
// decimal to decimal
// c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS
DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2)
- let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 10,
0)));
+ let expr_lt = cast(col("c3"), DataType::Decimal128(10, 0))
+ .lt(lit(ScalarValue::Decimal128(Some(123), 10, 0)));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(12300),
18, 2)));
assert_eq!(optimize_test(expr_lt, &schema), expected);
+
// c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS
DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2)
- let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(1230), 10,
3)));
+ let expr_lt = cast(col("c3"), DataType::Decimal128(10, 3))
+ .lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3)));
let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 18,
2)));
assert_eq!(optimize_test(expr_lt, &schema), expected);
// decimal to integer
// c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS
INT32) -> c1 < INT32(123)
- let expr_lt = col("c1").lt(lit(ScalarValue::Decimal128(Some(12300),
10, 2)));
+ let expr_lt = cast(col("c1"), DataType::Decimal128(10, 2))
+ .lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2)));
let expected = col("c1").lt(lit(ScalarValue::Int32(Some(123))));
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
#[test]
- fn test_not_list_cast_lit_comparison() {
+ fn test_not_unwrap_list_cast_lit_comparison() {
let schema = expr_test_schema();
- // left type is not supported
+ // internal left type is not supported
// FLOAT32(C5) in ...
- let expr_lt = col("c5").in_list(
+ let expr_lt = cast(col("c5"), DataType::Int64).in_list(
vec![
lit(ScalarValue::Int64(Some(12))),
- lit(ScalarValue::Int32(Some(12))),
+ lit(ScalarValue::Int64(Some(12))),
],
false,
);
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
- // INT32(C1) in (FLOAT32(1.23), INT32(12), INT64(12))
- let expr_lt = col("c1").in_list(
+ // cast(INT32(C1), Float32) in (FLOAT32(1.23), Float32(12),
Float32(12))
+ let expr_lt = cast(col("c1"), DataType::Float32).in_list(
vec![
- lit(ScalarValue::Int32(Some(12))),
- lit(ScalarValue::Int64(Some(12))),
+ lit(ScalarValue::Float32(Some(12.0))),
+ lit(ScalarValue::Float32(Some(12.0))),
lit(ScalarValue::Float32(Some(1.23))),
],
false,
@@ -461,7 +522,7 @@ mod tests {
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
// INT32(C1) in (INT64(99999999999), INT64(12))
- let expr_lt = col("c1").in_list(
+ let expr_lt = cast(col("c1"), DataType::Int64).in_list(
vec![
lit(ScalarValue::Int32(Some(12))),
lit(ScalarValue::Int64(Some(99999999999))),
@@ -471,10 +532,10 @@ mod tests {
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
// DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3))
- let expr_lt = col("c3").in_list(
+ let expr_lt = cast(col("c3"), DataType::Decimal128(12, 3)).in_list(
vec![
- lit(ScalarValue::Int32(Some(12))),
- lit(ScalarValue::Int64(Some(12))),
+ lit(ScalarValue::Decimal128(Some(12), 12, 3)),
+ lit(ScalarValue::Decimal128(Some(12), 12, 3)),
lit(ScalarValue::Decimal128(Some(128), 12, 3)),
],
false,
@@ -483,12 +544,12 @@ mod tests {
}
#[test]
- fn test_pre_list_cast_lit_comparison() {
+ fn test_unwrap_list_cast_comparison() {
let schema = expr_test_schema();
// INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN
(INT32(12),INT32(24))
- let expr_lt = col("c1").in_list(
+ let expr_lt = cast(col("c1"), DataType::Int64).in_list(
vec![
- lit(ScalarValue::Int32(Some(12))),
+ lit(ScalarValue::Int64(Some(12))),
lit(ScalarValue::Int64(Some(24))),
],
false,
@@ -502,9 +563,9 @@ mod tests {
);
assert_eq!(optimize_test(expr_lt, &schema), expected);
// INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN
(INT32(12),INT32(24))
- let expr_lt = col("c2").in_list(
+ let expr_lt = cast(col("c2"), DataType::Int32).in_list(
vec![
- lit(ScalarValue::Int64(None)),
+ lit(ScalarValue::Int32(None)),
lit(ScalarValue::Int32(Some(14))),
],
false,
@@ -520,12 +581,13 @@ mod tests {
assert_eq!(optimize_test(expr_lt, &schema), expected);
// decimal test case
- let expr_lt = col("c3").in_list(
+ // c3 is decimal(18,2)
+ let expr_lt = cast(col("c3"), DataType::Decimal128(19, 3)).in_list(
vec![
- lit(ScalarValue::Int32(Some(12))),
- lit(ScalarValue::Int64(Some(24))),
- lit(ScalarValue::Decimal128(Some(128), 10, 2)),
- lit(ScalarValue::Decimal128(Some(1280), 10, 3)),
+ lit(ScalarValue::Decimal128(Some(12000), 19, 3)),
+ lit(ScalarValue::Decimal128(Some(24000), 19, 3)),
+ lit(ScalarValue::Decimal128(Some(1280), 19, 3)),
+ lit(ScalarValue::Decimal128(Some(1240), 19, 3)),
],
false,
);
@@ -534,23 +596,23 @@ mod tests {
lit(ScalarValue::Decimal128(Some(1200), 18, 2)),
lit(ScalarValue::Decimal128(Some(2400), 18, 2)),
lit(ScalarValue::Decimal128(Some(128), 18, 2)),
- lit(ScalarValue::Decimal128(Some(128), 18, 2)),
+ lit(ScalarValue::Decimal128(Some(124), 18, 2)),
],
false,
);
assert_eq!(optimize_test(expr_lt, &schema), expected);
- // INT32(12) IN (.....)
- let expr_lt = lit(ScalarValue::Int32(Some(12))).in_list(
+ // cast(INT32(12), INT64) IN (.....)
+ let expr_lt = cast(lit(ScalarValue::Int32(Some(12))),
DataType::Int64).in_list(
vec![
- lit(ScalarValue::Int32(Some(12))),
+ lit(ScalarValue::Int64(Some(13))),
lit(ScalarValue::Int64(Some(12))),
],
false,
);
let expected = lit(ScalarValue::Int32(Some(12))).in_list(
vec![
- lit(ScalarValue::Int32(Some(12))),
+ lit(ScalarValue::Int32(Some(13))),
lit(ScalarValue::Int32(Some(12))),
],
false,
@@ -563,7 +625,9 @@ mod tests {
let schema = expr_test_schema();
// c1 < INT64(16) -> c1 < cast(INT32(16))
// the 16 is within the range of MAX(int32) and MIN(int32), we can
cast the 16 to int32(16)
- let expr_lt =
col("c1").lt(lit(ScalarValue::Int64(Some(16)))).alias("x");
+ let expr_lt = cast(col("c1"), DataType::Int64)
+ .lt(lit(ScalarValue::Int64(Some(16))))
+ .alias("x");
let expected =
col("c1").lt(lit(ScalarValue::Int32(Some(16)))).alias("x");
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
@@ -573,9 +637,9 @@ mod tests {
let schema = expr_test_schema();
// c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32)
// the 16 and 32 are within the range of MAX(int32) and MIN(int32), we
can cast them to int32
- let expr_lt = col("c1")
+ let expr_lt = cast(col("c1"), DataType::Int64)
.lt(lit(ScalarValue::Int64(Some(16))))
- .or(col("c1").gt(lit(ScalarValue::Int64(Some(32)))));
+ .or(cast(col("c1"),
DataType::Int64).gt(lit(ScalarValue::Int64(Some(32)))));
let expected = col("c1")
.lt(lit(ScalarValue::Int32(Some(16))))
.or(col("c1").gt(lit(ScalarValue::Int32(Some(32)))));
@@ -583,7 +647,7 @@ mod tests {
}
fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
- let mut expr_rewriter = PreCastLitExprRewriter {
+ let mut expr_rewriter = UnwrapCastExprRewriter {
schema: schema.clone(),
};
expr.rewrite(&mut expr_rewriter).unwrap()
diff --git a/datafusion/optimizer/tests/integration-test.rs
b/datafusion/optimizer/tests/integration-test.rs
index 61bfafed7..7811e475c 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -27,7 +27,6 @@ use
datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
use datafusion_optimizer::filter_push_down::FilterPushDown;
use datafusion_optimizer::limit_push_down::LimitPushDown;
use datafusion_optimizer::optimizer::Optimizer;
-use
datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
use datafusion_optimizer::projection_push_down::ProjectionPushDown;
use datafusion_optimizer::reduce_cross_join::ReduceCrossJoin;
use datafusion_optimizer::reduce_outer_join::ReduceOuterJoin;
@@ -37,6 +36,7 @@ use
datafusion_optimizer::simplify_expressions::SimplifyExpressions;
use datafusion_optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
use datafusion_optimizer::subquery_filter_to_join::SubqueryFilterToJoin;
use datafusion_optimizer::type_coercion::TypeCoercion;
+use datafusion_optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison;
use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
use datafusion_sql::planner::{ContextProvider, SqlToRel};
use datafusion_sql::sqlparser::ast::Statement;
@@ -107,9 +107,9 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
// TODO should make align with rules in the context
// https://github.com/apache/arrow-datafusion/issues/3524
let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
- Arc::new(PreCastLitInComparisonExpressions::new()),
Arc::new(TypeCoercion::new()),
Arc::new(SimplifyExpressions::new()),
+ Arc::new(UnwrapCastInComparison::new()),
Arc::new(DecorrelateWhereExists::new()),
Arc::new(DecorrelateWhereIn::new()),
Arc::new(ScalarSubqueryToJoin::new()),