This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 1e0c7607d5 feat: unwrap casts of string and dictionary columns (#10323)
1e0c7607d5 is described below
commit 1e0c7607d51faf698eef43a07de22a9578d83712
Author: Adam Curtis <[email protected]>
AuthorDate: Wed May 1 15:57:49 2024 -0400
feat: unwrap casts of string and dictionary columns (#10323)
* feat: unwrap casts of string and dictionary columns
* feat: allow unwrapping casts for any dictionary type
* docs: fix
* add LargeUtf8
* add explain test for integer cast
* remove unnecessary equality check
this should prevent returning Transformed in cases where nothing was
changed
* update comments
---
.../optimizer/src/unwrap_cast_in_comparison.rs | 233 ++++++++++++++++-----
datafusion/sqllogictest/test_files/dictionary.slt | 45 ++++
2 files changed, 229 insertions(+), 49 deletions(-)
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 138769674d..293a694d68 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -152,8 +152,8 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter {
let Ok(right_type) = right.get_type(&self.schema) else {
return Ok(Transformed::no(expr));
};
- is_support_data_type(&left_type)
- && is_support_data_type(&right_type)
+ is_supported_type(&left_type)
+ && is_supported_type(&right_type)
&& is_comparison_op(op)
} =>
{
@@ -172,7 +172,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter {
let Ok(expr_type) = right_expr.get_type(&self.schema)
else {
return Ok(Transformed::no(expr));
};
- let Ok(Some(value)) =
+ let Some(value) =
try_cast_literal_to_type(left_lit_value,
&expr_type)
else {
return Ok(Transformed::no(expr));
@@ -196,7 +196,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter {
let Ok(expr_type) = left_expr.get_type(&self.schema)
else {
return Ok(Transformed::no(expr));
};
- let Ok(Some(value)) =
+ let Some(value) =
try_cast_literal_to_type(right_lit_value,
&expr_type)
else {
return Ok(Transformed::no(expr));
@@ -226,14 +226,14 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter {
let Ok(expr_type) = left_expr.get_type(&self.schema) else {
return Ok(Transformed::no(expr));
};
- if !is_support_data_type(&expr_type) {
+ if !is_supported_type(&expr_type) {
return Ok(Transformed::no(expr));
}
let Ok(right_exprs) = list
.iter()
.map(|right| {
let right_type = right.get_type(&self.schema)?;
- if !is_support_data_type(&right_type) {
+ if !is_supported_type(&right_type) {
internal_err!(
"The type of list expr {} is not supported",
&right_type
@@ -243,7 +243,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter {
Expr::Literal(right_lit_value) => {
// if the right_lit_value can be casted to the
type of internal_left_expr
// we need to unwrap the cast for
cast/try_cast expr, and add cast to the literal
- let Ok(Some(value)) =
try_cast_literal_to_type(right_lit_value, &expr_type) else {
+ let Some(value) =
try_cast_literal_to_type(right_lit_value, &expr_type) else {
internal_err!(
"Can't cast the list expr {:?} to type
{:?}",
right_lit_value, &expr_type
@@ -282,7 +282,15 @@ fn is_comparison_op(op: &Operator) -> bool {
)
}
-fn is_support_data_type(data_type: &DataType) -> bool {
+/// Returns true if [UnwrapCastExprRewriter] supports this data type
+fn is_supported_type(data_type: &DataType) -> bool {
+ is_supported_numeric_type(data_type)
+ || is_supported_string_type(data_type)
+ || is_supported_dictionary_type(data_type)
+}
+
+/// Returns true if [UnwrapCastExprRewriter] supports this numeric type
+fn is_supported_numeric_type(data_type: &DataType) -> bool {
matches!(
data_type,
DataType::UInt8
@@ -298,19 +306,47 @@ fn is_support_data_type(data_type: &DataType) -> bool {
)
}
+/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a
string
+fn is_supported_string_type(data_type: &DataType) -> bool {
+ matches!(data_type, DataType::Utf8 | DataType::LargeUtf8)
+}
+
+/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a
dictionary
+fn is_supported_dictionary_type(data_type: &DataType) -> bool {
+ matches!(data_type,
+ DataType::Dictionary(_, inner) if is_supported_type(inner))
+}
+
+/// Convert a literal value to `target_type`, returning `None` if either type is unsupported or the value cannot be represented in `target_type`
fn try_cast_literal_to_type(
lit_value: &ScalarValue,
target_type: &DataType,
-) -> Result<Option<ScalarValue>> {
+) -> Option<ScalarValue> {
let lit_data_type = lit_value.data_type();
- // the rule just support the signed numeric data type now
- if !is_support_data_type(&lit_data_type) ||
!is_support_data_type(target_type) {
- return Ok(None);
+ if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) {
+ return None;
}
if lit_value.is_null() {
// null value can be cast to any type of null value
- return Ok(Some(ScalarValue::try_from(target_type)?));
+ return ScalarValue::try_from(target_type).ok();
+ }
+ try_cast_numeric_literal(lit_value, target_type)
+ .or_else(|| try_cast_string_literal(lit_value, target_type))
+ .or_else(|| try_cast_dictionary(lit_value, target_type))
+}
+
+/// Convert a numeric value from one numeric data type to another
+fn try_cast_numeric_literal(
+ lit_value: &ScalarValue,
+ target_type: &DataType,
+) -> Option<ScalarValue> {
+ let lit_data_type = lit_value.data_type();
+ if !is_supported_numeric_type(&lit_data_type)
+ || !is_supported_numeric_type(target_type)
+ {
+ return None;
}
+
let mul = match target_type {
DataType::UInt8
| DataType::UInt16
@@ -322,9 +358,7 @@ fn try_cast_literal_to_type(
| DataType::Int64 => 1_i128,
DataType::Timestamp(_, _) => 1_i128,
DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
- other_type => {
- return internal_err!("Error target data type {other_type:?}");
- }
+ _ => return None,
};
let (target_min, target_max) = match target_type {
DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
@@ -343,9 +377,7 @@ fn try_cast_literal_to_type(
MIN_DECIMAL_FOR_EACH_PRECISION[*precision as usize - 1],
MAX_DECIMAL_FOR_EACH_PRECISION[*precision as usize - 1],
),
- other_type => {
- return internal_err!("Error target data type {other_type:?}");
- }
+ _ => return None,
};
let lit_value_target_type = match lit_value {
ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
@@ -379,13 +411,11 @@ fn try_cast_literal_to_type(
None
}
}
- other_value => {
- return internal_err!("Invalid literal value {other_value:?}");
- }
+ _ => None,
};
match lit_value_target_type {
- None => Ok(None),
+ None => None,
Some(value) => {
if value >= target_min && value <= target_max {
// the value casted from lit to the target type is in the
range of target type.
@@ -434,18 +464,60 @@ fn try_cast_literal_to_type(
DataType::Decimal128(p, s) => {
ScalarValue::Decimal128(Some(value), *p, *s)
}
- other_type => {
- return internal_err!("Error target data type
{other_type:?}");
+ _ => {
+ return None;
}
};
- Ok(Some(result_scalar))
+ Some(result_scalar)
} else {
- Ok(None)
+ None
}
}
}
}
+fn try_cast_string_literal(
+ lit_value: &ScalarValue,
+ target_type: &DataType,
+) -> Option<ScalarValue> {
+ let string_value = match lit_value {
+ ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) => s.clone(),
+ _ => return None,
+ };
+ let scalar_value = match target_type {
+ DataType::Utf8 => ScalarValue::Utf8(string_value),
+ DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
+ _ => return None,
+ };
+ Some(scalar_value)
+}
+
+/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the
dictionary
+fn try_cast_dictionary(
+ lit_value: &ScalarValue,
+ target_type: &DataType,
+) -> Option<ScalarValue> {
+ let lit_value_type = lit_value.data_type();
+ let result_scalar = match (lit_value, target_type) {
+ // Unwrap dictionary when inner type matches target type
+ (ScalarValue::Dictionary(_, inner_value), _)
+ if inner_value.data_type() == *target_type =>
+ {
+ (**inner_value).clone()
+ }
+ // Wrap type when target type is dictionary
+ (_, DataType::Dictionary(index_type, inner_type))
+ if **inner_type == lit_value_type =>
+ {
+ ScalarValue::Dictionary(index_type.clone(),
Box::new(lit_value.clone()))
+ }
+ _ => {
+ return None;
+ }
+ };
+ Some(result_scalar)
+}
+
/// Cast a timestamp value from one unit to another
fn cast_between_timestamp(from: DataType, to: DataType, value: i128) ->
Option<i64> {
let value = value as i64;
@@ -536,6 +608,35 @@ mod tests {
assert_eq!(optimize_test(expr_input, &schema), expected);
}
+ #[test]
+ fn test_unwrap_cast_comparison_string() {
+ let schema = expr_test_schema();
+ let dict = ScalarValue::Dictionary(
+ Box::new(DataType::Int32),
+ Box::new(ScalarValue::from("value")),
+ );
+
+ // cast(str1 as Dictionary<Int32, Utf8>) = arrow_cast('value',
'Dictionary<Int32, Utf8>') => str1 = Utf8('value')
+ let expr_input = cast(col("str1"),
dict.data_type()).eq(lit(dict.clone()));
+ let expected = col("str1").eq(lit("value"));
+ assert_eq!(optimize_test(expr_input, &schema), expected);
+
+ // cast(tag as Utf8) = Utf8('value') => tag = arrow_cast('value',
'Dictionary<Int32, Utf8>')
+ let expr_input = cast(col("tag"), DataType::Utf8).eq(lit("value"));
+ let expected = col("tag").eq(lit(dict));
+ assert_eq!(optimize_test(expr_input, &schema), expected);
+
+ // cast(largestr as Dictionary<Int32, LargeUtf8>) =
arrow_cast('value', 'Dictionary<Int32, LargeUtf8>') => largestr =
LargeUtf8('value')
+ let dict = ScalarValue::Dictionary(
+ Box::new(DataType::Int32),
+ Box::new(ScalarValue::LargeUtf8(Some("value".to_owned()))),
+ );
+ let expr_input = cast(col("largestr"), dict.data_type()).eq(lit(dict));
+ let expected =
+
col("largestr").eq(lit(ScalarValue::LargeUtf8(Some("value".to_owned()))));
+ assert_eq!(optimize_test(expr_input, &schema), expected);
+ }
+
#[test]
fn test_not_unwrap_cast_with_decimal_comparison() {
let schema = expr_test_schema();
@@ -746,6 +847,9 @@ mod tests {
Field::new("c6", DataType::UInt32, false),
Field::new("ts_nano_none", timestamp_nano_none_type(),
false),
Field::new("ts_nano_utf", timestamp_nano_utc_type(),
false),
+ Field::new("str1", DataType::Utf8, false),
+ Field::new("largestr", DataType::LargeUtf8, false),
+ Field::new("tag", dictionary_tag_type(), false),
]
.into(),
HashMap::new(),
@@ -793,6 +897,11 @@ mod tests {
DataType::Timestamp(TimeUnit::Nanosecond, utc)
}
+    // a dictionary type for storing string tags
+ fn dictionary_tag_type() -> DataType {
+ DataType::Dictionary(Box::new(DataType::Int32),
Box::new(DataType::Utf8))
+ }
+
#[test]
fn test_try_cast_to_type_nulls() {
// test that nulls can be cast to/from all integer types
@@ -807,6 +916,8 @@ mod tests {
ScalarValue::UInt64(None),
ScalarValue::Decimal128(None, 3, 0),
ScalarValue::Decimal128(None, 8, 2),
+ ScalarValue::Utf8(None),
+ ScalarValue::LargeUtf8(None),
];
for s1 in &scalars {
@@ -1061,18 +1172,17 @@ mod tests {
target_type: DataType,
expected_result: ExpectedCast,
) {
- let actual_result = try_cast_literal_to_type(&literal, &target_type);
+ let actual_value = try_cast_literal_to_type(&literal, &target_type);
println!("expect_cast: ");
println!(" {literal:?} --> {target_type:?}");
println!(" expected_result: {expected_result:?}");
- println!(" actual_result: {actual_result:?}");
+ println!(" actual_result: {actual_value:?}");
match expected_result {
ExpectedCast::Value(expected_value) => {
- let actual_value = actual_result
- .expect("Expected success but got error")
- .expect("Expected cast value but got None");
+ let actual_value =
+ actual_value.expect("Expected cast value but got None");
assert_eq!(actual_value, expected_value);
@@ -1094,7 +1204,7 @@ mod tests {
assert_eq!(
&expected_array, &cast_array,
- "Result of casing {literal:?} with arrow was\n
{cast_array:#?}\nbut expected\n{expected_array:#?}"
+ "Result of casting {literal:?} with arrow was\n
{cast_array:#?}\nbut expected\n{expected_array:#?}"
);
// Verify that for timestamp types the timezones are the same
@@ -1109,8 +1219,6 @@ mod tests {
}
}
ExpectedCast::NoValue => {
- let actual_value = actual_result.expect("Expected success but
got error");
-
assert!(
actual_value.is_none(),
"Expected no cast value, but got {actual_value:?}"
@@ -1126,7 +1234,6 @@ mod tests {
&ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
- .unwrap()
.unwrap();
assert_eq!(
@@ -1139,7 +1246,6 @@ mod tests {
&ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Microsecond, None),
)
- .unwrap()
.unwrap();
assert_eq!(
@@ -1152,7 +1258,6 @@ mod tests {
&ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
- .unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0),
None));
@@ -1162,7 +1267,6 @@ mod tests {
&ScalarValue::TimestampNanosecond(Some(123456), None),
&DataType::Timestamp(TimeUnit::Second, None),
)
- .unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));
@@ -1172,7 +1276,6 @@ mod tests {
&ScalarValue::TimestampMicrosecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
- .unwrap()
.unwrap();
assert_eq!(
@@ -1185,7 +1288,6 @@ mod tests {
&ScalarValue::TimestampMicrosecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
- .unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0),
None));
@@ -1195,7 +1297,6 @@ mod tests {
&ScalarValue::TimestampMicrosecond(Some(123456789), None),
&DataType::Timestamp(TimeUnit::Second, None),
)
- .unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));
@@ -1204,7 +1305,6 @@ mod tests {
&ScalarValue::TimestampMillisecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
- .unwrap()
.unwrap();
assert_eq!(
new_scalar,
@@ -1216,7 +1316,6 @@ mod tests {
&ScalarValue::TimestampMillisecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Microsecond, None),
)
- .unwrap()
.unwrap();
assert_eq!(
new_scalar,
@@ -1227,7 +1326,6 @@ mod tests {
&ScalarValue::TimestampMillisecond(Some(123456789), None),
&DataType::Timestamp(TimeUnit::Second, None),
)
- .unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456),
None));
@@ -1236,7 +1334,6 @@ mod tests {
&ScalarValue::TimestampSecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Nanosecond, None),
)
- .unwrap()
.unwrap();
assert_eq!(
new_scalar,
@@ -1248,7 +1345,6 @@ mod tests {
&ScalarValue::TimestampSecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Microsecond, None),
)
- .unwrap()
.unwrap();
assert_eq!(
new_scalar,
@@ -1260,7 +1356,6 @@ mod tests {
&ScalarValue::TimestampSecond(Some(123), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
- .unwrap()
.unwrap();
assert_eq!(
new_scalar,
@@ -1272,8 +1367,48 @@ mod tests {
&ScalarValue::TimestampSecond(Some(i64::MAX), None),
&DataType::Timestamp(TimeUnit::Millisecond, None),
)
- .unwrap()
.unwrap();
assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
}
+
+ #[test]
+ fn test_try_cast_to_string_type() {
+ let scalars = vec![
+ ScalarValue::from("string"),
+ ScalarValue::LargeUtf8(Some("string".to_owned())),
+ ];
+
+ for s1 in &scalars {
+ for s2 in &scalars {
+ let expected_value = ExpectedCast::Value(s2.clone());
+
+ expect_cast(s1.clone(), s2.data_type(), expected_value);
+ }
+ }
+ }
+ #[test]
+ fn test_try_cast_to_dictionary_type() {
+ fn dictionary_type(t: DataType) -> DataType {
+ DataType::Dictionary(Box::new(DataType::Int32), Box::new(t))
+ }
+ fn dictionary_value(value: ScalarValue) -> ScalarValue {
+ ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value))
+ }
+ let scalars = vec![
+ ScalarValue::from("string"),
+ ScalarValue::LargeUtf8(Some("string".to_owned())),
+ ];
+ for s in &scalars {
+ expect_cast(
+ s.clone(),
+ dictionary_type(s.data_type()),
+ ExpectedCast::Value(dictionary_value(s.clone())),
+ );
+ expect_cast(
+ dictionary_value(s.clone()),
+ s.data_type(),
+ ExpectedCast::Value(s.clone()),
+ )
+ }
+ }
}
diff --git a/datafusion/sqllogictest/test_files/dictionary.slt
b/datafusion/sqllogictest/test_files/dictionary.slt
index 06b7265026..7e45f5e444 100644
--- a/datafusion/sqllogictest/test_files/dictionary.slt
+++ b/datafusion/sqllogictest/test_files/dictionary.slt
@@ -386,3 +386,48 @@ drop table m3;
statement ok
drop table m3_source;
+
+
+## Test that filtering on dictionary columns coerces the filter value to the
dictionary type
+statement ok
+create table test as values
+ ('row1', arrow_cast('1', 'Dictionary(Int32, Utf8)')),
+ ('row2', arrow_cast('2', 'Dictionary(Int32, Utf8)')),
+ ('row3', arrow_cast('3', 'Dictionary(Int32, Utf8)'))
+;
+
+# query using a string '1' which must be coerced into a dictionary string
+query T?
+SELECT * from test where column2 = '1';
+----
+row1 1
+
+# filter should not have a cast on column2
+query TT
+explain SELECT * from test where column2 = '1';
+----
+logical_plan
+01)Filter: test.column2 = Dictionary(Int32, Utf8("1"))
+02)--TableScan: test projection=[column1, column2]
+physical_plan
+01)CoalesceBatchesExec: target_batch_size=8192
+02)--FilterExec: column2@1 = 1
+03)----MemoryExec: partitions=1, partition_sizes=[1]
+
+
+# Now query using an integer which must be coerced into a dictionary string
+query T?
+SELECT * from test where column2 = 1;
+----
+row1 1
+
+query TT
+explain SELECT * from test where column2 = 1;
+----
+logical_plan
+01)Filter: test.column2 = Dictionary(Int32, Utf8("1"))
+02)--TableScan: test projection=[column1, column2]
+physical_plan
+01)CoalesceBatchesExec: target_batch_size=8192
+02)--FilterExec: column2@1 = 1
+03)----MemoryExec: partitions=1, partition_sizes=[1]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]