This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 84153682b [MINOR] Add `ScalarValue::new_utf8`, clean up creation of literals in casting tests (#3680)
84153682b is described below
commit 84153682b5692e786292c672798609adaeb481f4
Author: Andrew Lamb <[email protected]>
AuthorDate: Mon Oct 3 12:07:42 2022 -0400
[MINOR] Add `ScalarValue::new_utf8`, clean up creation of literals in casting tests (#3680)
* Add ScalarValue::new_utf8, clean up construction in tests
* Cleanup
* some more minor cleanup
* Update unwrap_cast
---
datafusion/common/src/scalar.rs | 7 +-
datafusion/optimizer/src/type_coercion.rs | 32 +--
.../optimizer/src/unwrap_cast_in_comparison.rs | 215 +++++++++------------
3 files changed, 102 insertions(+), 152 deletions(-)
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 2f8f12558..42f0a7d16 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -54,7 +54,7 @@ pub enum ScalarValue {
Float32(Option<f32>),
/// 64bit float
Float64(Option<f64>),
- /// 128bit decimal, using the i128 to represent the decimal
+ /// 128bit decimal, using the i128 to represent the decimal, precision
scale
Decimal128(Option<i128>, u8, u8),
/// signed 8bit int
Int8(Option<i8>),
@@ -816,6 +816,11 @@ impl ScalarValue {
)))
}
+ /// Returns a [`ScalarValue::Utf8`] representing `val`
+ pub fn new_utf8(val: impl Into<String>) -> Self {
+ ScalarValue::Utf8(Some(val.into()))
+ }
+
/// Returns a [`ScalarValue::IntervalYearMonth`] representing
/// `years` years and `months` months
pub fn new_interval_ym(years: i32, months: i32) -> Self {
diff --git a/datafusion/optimizer/src/type_coercion.rs
b/datafusion/optimizer/src/type_coercion.rs
index 3d45c5041..5a53adc26 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -421,8 +421,8 @@ mod test {
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
- Expr, LogicalPlan, Operator, ReturnTypeFunction,
ScalarFunctionImplementation,
- ScalarUDF, Signature, Volatility,
+ Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation,
ScalarUDF,
+ Signature, Volatility,
};
use std::sync::Arc;
@@ -484,11 +484,8 @@ mod test {
let empty = empty();
let return_type: ReturnTypeFunction =
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
- let fun: ScalarFunctionImplementation = Arc::new(move |_| {
- Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
- "a".to_string(),
- ))))
- });
+ let fun: ScalarFunctionImplementation =
+ Arc::new(move |_|
Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
let udf = Expr::ScalarUDF {
fun: Arc::new(ScalarUDF::new(
"TestScalarUDF",
@@ -538,16 +535,8 @@ mod test {
#[test]
fn binary_op_date32_add_interval() -> Result<()> {
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
- let expr = Expr::BinaryExpr {
- left: Box::new(Expr::Cast {
- expr: Box::new(lit("1998-03-18")),
- data_type: DataType::Date32,
- }),
- op: Operator::Plus,
- right: Box::new(Expr::Literal(ScalarValue::IntervalDayTime(Some(
- 386547056640,
- )))),
- };
+ let expr = cast(lit("1998-03-18"), DataType::Date32)
+ + lit(ScalarValue::IntervalDayTime(Some(386547056640)));
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
@@ -665,7 +654,7 @@ mod test {
fn like_for_type_coercion() -> Result<()> {
// like : utf8 like "abc"
let expr = Box::new(col("a"));
- let pattern =
Box::new(lit(ScalarValue::Utf8(Some("abc".to_string()))));
+ let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
let like_expr = Expr::Like {
negated: false,
expr,
@@ -703,7 +692,7 @@ mod test {
);
let expr = Box::new(col("a"));
- let pattern =
Box::new(lit(ScalarValue::Utf8(Some("abc".to_string()))));
+ let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
let like_expr = Expr::Like {
negated: false,
expr,
@@ -792,10 +781,7 @@ mod test {
);
let mut rewriter = TypeCoercionRewriter { schema };
let expr = is_true(lit(12i32).eq(lit(13i64)));
- let expected = is_true(
- cast(lit(ScalarValue::Int32(Some(12))), DataType::Int64)
- .eq(lit(ScalarValue::Int64(Some(13)))),
- );
+ let expected = is_true(cast(lit(12i32),
DataType::Int64).eq(lit(13i64)));
let result = expr.rewrite(&mut rewriter)?;
assert_eq!(expected, result);
Ok(())
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 0d5665f29..0f7238d33 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -385,12 +385,11 @@ mod tests {
assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2);
// INT32(c1) < INT32(16), the type is same
- let expr_lt = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
+ let expr_lt = col("c1").lt(lit(16i32));
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
// the 99999999999 is not within the range of MAX(int32) and
MIN(int32), we don't cast the lit(99999999999) to int32 type
- let expr_lt = cast(col("c1"), DataType::Int64)
- .lt(lit(ScalarValue::Int64(Some(99999999999))));
+ let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64));
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
}
@@ -399,31 +398,26 @@ mod tests {
let schema = expr_test_schema();
// cast(c1, INT64) < INT64(16) -> INT32(c1) < cast(INT32(16))
// the 16 is within the range of MAX(int32) and MIN(int32), we can
cast the 16 to int32(16)
- let expr_lt =
- cast(col("c1"),
DataType::Int64).lt(lit(ScalarValue::Int64(Some(16))));
- let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
+ let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64));
+ let expected = col("c1").lt(lit(16i32));
assert_eq!(optimize_test(expr_lt, &schema), expected);
- let expr_lt =
- try_cast(col("c1"),
DataType::Int64).lt(lit(ScalarValue::Int64(Some(16))));
- let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16))));
+ let expr_lt = try_cast(col("c1"), DataType::Int64).lt(lit(16i64));
+ let expected = col("c1").lt(lit(16i32));
assert_eq!(optimize_test(expr_lt, &schema), expected);
// cast(c2, INT32) = INT32(16) => INT64(c2) = INT64(16)
- let c2_eq_lit =
- cast(col("c2"),
DataType::Int32).eq(lit(ScalarValue::Int32(Some(16))));
- let expected = col("c2").eq(lit(ScalarValue::Int64(Some(16))));
+ let c2_eq_lit = cast(col("c2"), DataType::Int32).eq(lit(16i32));
+ let expected = col("c2").eq(lit(16i64));
assert_eq!(optimize_test(c2_eq_lit, &schema), expected);
// cast(c1, INT64) < INT64(NULL) => INT32(c1) < INT32(NULL)
- let c1_lt_lit_null =
- cast(col("c1"), DataType::Int64).lt(lit(ScalarValue::Int64(None)));
- let expected = col("c1").lt(lit(ScalarValue::Int32(None)));
+ let c1_lt_lit_null = cast(col("c1"), DataType::Int64).lt(null_i64());
+ let expected = col("c1").lt(null_i32());
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
// cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12)
- let lit_lt_lit = cast(lit(ScalarValue::Int8(None)), DataType::Int32)
- .lt(lit(ScalarValue::Int32(Some(12))));
- let expected =
lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int8(Some(12))));
+ let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32));
+ let expected = null_i8().lt(lit(12i8));
assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
}
@@ -432,30 +426,28 @@ mod tests {
let schema = expr_test_schema();
// integer to decimal: value is out of the bounds of the decimal
// cast(c3, INT64) = INT64(100000000000000000)
- let expr_eq = cast(col("c3"), DataType::Int64)
- .eq(lit(ScalarValue::Int64(Some(100000000000000000))));
+ let expr_eq = cast(col("c3"),
DataType::Int64).eq(lit(100000000000000000i64));
assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
// cast(c4, INT64) = INT64(1000) will overflow the i128
- let expr_eq =
- cast(col("c4"),
DataType::Int64).eq(lit(ScalarValue::Int64(Some(1000))));
+ let expr_eq = cast(col("c4"), DataType::Int64).eq(lit(1000i64));
assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
// decimal to decimal: value will lose the scale when convert to the
target data type
// c3 = DECIMAL(12340,20,4)
- let expr_eq = cast(col("c3"), DataType::Decimal128(20, 4))
- .eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4)));
+ let expr_eq =
+ cast(col("c3"), DataType::Decimal128(20, 4)).eq(lit_decimal(12340,
20, 4));
assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
// decimal to integer
// c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to
the target data type
- let expr_eq = cast(col("c1"), DataType::Decimal128(10, 1))
- .eq(lit(ScalarValue::Decimal128(Some(123), 10, 1)));
+ let expr_eq =
+ cast(col("c1"), DataType::Decimal128(10, 1)).eq(lit_decimal(123,
10, 1));
assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
// c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert
to the target data type
- let expr_eq = cast(col("c1"), DataType::Decimal128(10, 2))
- .eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2)));
+ let expr_eq =
+ cast(col("c1"), DataType::Decimal128(10, 2)).eq(lit_decimal(1230,
10, 2));
assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
}
@@ -464,35 +456,33 @@ mod tests {
let schema = expr_test_schema();
// integer to decimal
// c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2));
- let expr_lt =
- try_cast(col("c3"),
DataType::Int64).lt(lit(ScalarValue::Int64(Some(16))));
- let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(1600),
18, 2)));
+ let expr_lt = try_cast(col("c3"), DataType::Int64).lt(lit(16i64));
+ let expected = col("c3").lt(lit_decimal(1600, 18, 2));
assert_eq!(optimize_test(expr_lt, &schema), expected);
// c3 < INT64(NULL)
- let c1_lt_lit_null =
- cast(col("c3"), DataType::Int64).lt(lit(ScalarValue::Int64(None)));
- let expected = col("c3").lt(lit(ScalarValue::Decimal128(None, 18, 2)));
+ let c1_lt_lit_null = cast(col("c3"), DataType::Int64).lt(null_i64());
+ let expected = col("c3").lt(null_decimal(18, 2));
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
// decimal to decimal
// c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS
DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2)
- let expr_lt = cast(col("c3"), DataType::Decimal128(10, 0))
- .lt(lit(ScalarValue::Decimal128(Some(123), 10, 0)));
- let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(12300),
18, 2)));
+ let expr_lt =
+ cast(col("c3"), DataType::Decimal128(10, 0)).lt(lit_decimal(123,
10, 0));
+ let expected = col("c3").lt(lit_decimal(12300, 18, 2));
assert_eq!(optimize_test(expr_lt, &schema), expected);
// c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS
DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2)
- let expr_lt = cast(col("c3"), DataType::Decimal128(10, 3))
- .lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3)));
- let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 18,
2)));
+ let expr_lt =
+ cast(col("c3"), DataType::Decimal128(10, 3)).lt(lit_decimal(1230,
10, 3));
+ let expected = col("c3").lt(lit_decimal(123, 18, 2));
assert_eq!(optimize_test(expr_lt, &schema), expected);
// decimal to integer
// c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS
INT32) -> c1 < INT32(123)
- let expr_lt = cast(col("c1"), DataType::Decimal128(10, 2))
- .lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2)));
- let expected = col("c1").lt(lit(ScalarValue::Int32(Some(123))));
+ let expr_lt =
+ cast(col("c1"), DataType::Decimal128(10, 2)).lt(lit_decimal(12300,
10, 2));
+ let expected = col("c1").lt(lit(123i32));
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
@@ -501,42 +491,26 @@ mod tests {
let schema = expr_test_schema();
// internal left type is not supported
// FLOAT32(C5) in ...
- let expr_lt = cast(col("c5"), DataType::Int64).in_list(
- vec![
- lit(ScalarValue::Int64(Some(12))),
- lit(ScalarValue::Int64(Some(12))),
- ],
- false,
- );
+ let expr_lt =
+ cast(col("c5"), DataType::Int64).in_list(vec![lit(12i64),
lit(12i64)], false);
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
// cast(INT32(C1), Float32) in (FLOAT32(1.23), Float32(12),
Float32(12))
- let expr_lt = cast(col("c1"), DataType::Float32).in_list(
- vec![
- lit(ScalarValue::Float32(Some(12.0))),
- lit(ScalarValue::Float32(Some(12.0))),
- lit(ScalarValue::Float32(Some(1.23))),
- ],
- false,
- );
+ let expr_lt = cast(col("c1"), DataType::Float32)
+ .in_list(vec![lit(12.0f32), lit(12.0f32), lit(1.23f32)], false);
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
// INT32(C1) in (INT64(99999999999), INT64(12))
- let expr_lt = cast(col("c1"), DataType::Int64).in_list(
- vec![
- lit(ScalarValue::Int32(Some(12))),
- lit(ScalarValue::Int64(Some(99999999999))),
- ],
- false,
- );
+ let expr_lt = cast(col("c1"), DataType::Int64)
+ .in_list(vec![lit(12i32), lit(99999999999i64)], false);
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
// DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3))
let expr_lt = cast(col("c3"), DataType::Decimal128(12, 3)).in_list(
vec![
- lit(ScalarValue::Decimal128(Some(12), 12, 3)),
- lit(ScalarValue::Decimal128(Some(12), 12, 3)),
- lit(ScalarValue::Decimal128(Some(128), 12, 3)),
+ lit_decimal(12, 12, 3),
+ lit_decimal(12, 12, 3),
+ lit_decimal(128, 12, 3),
],
false,
);
@@ -547,36 +521,14 @@ mod tests {
fn test_unwrap_list_cast_comparison() {
let schema = expr_test_schema();
// INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN
(INT32(12),INT32(24))
- let expr_lt = cast(col("c1"), DataType::Int64).in_list(
- vec![
- lit(ScalarValue::Int64(Some(12))),
- lit(ScalarValue::Int64(Some(24))),
- ],
- false,
- );
- let expected = col("c1").in_list(
- vec![
- lit(ScalarValue::Int32(Some(12))),
- lit(ScalarValue::Int32(Some(24))),
- ],
- false,
- );
+ let expr_lt =
+ cast(col("c1"), DataType::Int64).in_list(vec![lit(12i64),
lit(24i64)], false);
+ let expected = col("c1").in_list(vec![lit(12i32), lit(24i32)], false);
assert_eq!(optimize_test(expr_lt, &schema), expected);
// INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN
(INT32(12),INT32(24))
- let expr_lt = cast(col("c2"), DataType::Int32).in_list(
- vec![
- lit(ScalarValue::Int32(None)),
- lit(ScalarValue::Int32(Some(14))),
- ],
- false,
- );
- let expected = col("c2").in_list(
- vec![
- lit(ScalarValue::Int64(None)),
- lit(ScalarValue::Int64(Some(14))),
- ],
- false,
- );
+ let expr_lt =
+ cast(col("c2"), DataType::Int32).in_list(vec![null_i32(),
lit(14i32)], false);
+ let expected = col("c2").in_list(vec![null_i64(), lit(14i64)], false);
assert_eq!(optimize_test(expr_lt, &schema), expected);
@@ -584,39 +536,28 @@ mod tests {
// c3 is decimal(18,2)
let expr_lt = cast(col("c3"), DataType::Decimal128(19, 3)).in_list(
vec![
- lit(ScalarValue::Decimal128(Some(12000), 19, 3)),
- lit(ScalarValue::Decimal128(Some(24000), 19, 3)),
- lit(ScalarValue::Decimal128(Some(1280), 19, 3)),
- lit(ScalarValue::Decimal128(Some(1240), 19, 3)),
+ lit_decimal(12000, 19, 3),
+ lit_decimal(24000, 19, 3),
+ lit_decimal(1280, 19, 3),
+ lit_decimal(1240, 19, 3),
],
false,
);
let expected = col("c3").in_list(
vec![
- lit(ScalarValue::Decimal128(Some(1200), 18, 2)),
- lit(ScalarValue::Decimal128(Some(2400), 18, 2)),
- lit(ScalarValue::Decimal128(Some(128), 18, 2)),
- lit(ScalarValue::Decimal128(Some(124), 18, 2)),
+ lit_decimal(1200, 18, 2),
+ lit_decimal(2400, 18, 2),
+ lit_decimal(128, 18, 2),
+ lit_decimal(124, 18, 2),
],
false,
);
assert_eq!(optimize_test(expr_lt, &schema), expected);
// cast(INT32(12), INT64) IN (.....)
- let expr_lt = cast(lit(ScalarValue::Int32(Some(12))),
DataType::Int64).in_list(
- vec![
- lit(ScalarValue::Int64(Some(13))),
- lit(ScalarValue::Int64(Some(12))),
- ],
- false,
- );
- let expected = lit(ScalarValue::Int32(Some(12))).in_list(
- vec![
- lit(ScalarValue::Int32(Some(13))),
- lit(ScalarValue::Int32(Some(12))),
- ],
- false,
- );
+ let expr_lt = cast(lit(12i32), DataType::Int64)
+ .in_list(vec![lit(13i64), lit(12i64)], false);
+ let expected = lit(12i32).in_list(vec![lit(13i32), lit(12i32)], false);
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
@@ -625,10 +566,8 @@ mod tests {
let schema = expr_test_schema();
// c1 < INT64(16) -> c1 < cast(INT32(16))
// the 16 is within the range of MAX(int32) and MIN(int32), we can
cast the 16 to int32(16)
- let expr_lt = cast(col("c1"), DataType::Int64)
- .lt(lit(ScalarValue::Int64(Some(16))))
- .alias("x");
- let expected =
col("c1").lt(lit(ScalarValue::Int32(Some(16)))).alias("x");
+ let expr_lt = cast(col("c1"),
DataType::Int64).lt(lit(16i64)).alias("x");
+ let expected = col("c1").lt(lit(16i32)).alias("x");
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
@@ -637,12 +576,12 @@ mod tests {
let schema = expr_test_schema();
// c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32)
// the 16 and 32 are within the range of MAX(int32) and MIN(int32), we
can cast them to int32
- let expr_lt = cast(col("c1"), DataType::Int64)
- .lt(lit(ScalarValue::Int64(Some(16))))
- .or(cast(col("c1"),
DataType::Int64).gt(lit(ScalarValue::Int64(Some(32)))));
- let expected = col("c1")
- .lt(lit(ScalarValue::Int32(Some(16))))
- .or(col("c1").gt(lit(ScalarValue::Int32(Some(32)))));
+ let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).or(cast(
+ col("c1"),
+ DataType::Int64,
+ )
+ .gt(lit(32i64)));
+ let expected = col("c1").lt(lit(16i32)).or(col("c1").gt(lit(32i32)));
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
@@ -668,4 +607,24 @@ mod tests {
.unwrap(),
)
}
+
+ fn null_i8() -> Expr {
+ lit(ScalarValue::Int8(None))
+ }
+
+ fn null_i32() -> Expr {
+ lit(ScalarValue::Int32(None))
+ }
+
+ fn null_i64() -> Expr {
+ lit(ScalarValue::Int64(None))
+ }
+
+ fn lit_decimal(value: i128, precision: u8, scale: u8) -> Expr {
+ lit(ScalarValue::Decimal128(Some(value), precision, scale))
+ }
+
+ fn null_decimal(precision: u8, scale: u8) -> Expr {
+ lit(ScalarValue::Decimal128(None, precision, scale))
+ }
}