This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d076ab3b2 Bug/union wrong casting (#5342)
d076ab3b2 is described below
commit d076ab3b28ea0dffe65577dfa67dc7130c6006ce
Author: Berkay Şahin <[email protected]>
AuthorDate: Wed Mar 1 01:49:25 2023 +0300
Bug/union wrong casting (#5342)
* Fix union wrong casting
* Test format modified
* retry tests
* Signed integers never lose sign info
* Add comments explaining the match patterns
---------
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
datafusion/core/tests/sql/union.rs | 46 +++++++++++++++++++++++++++++
datafusion/expr/src/type_coercion/binary.rs | 31 ++++++++++++++++---
2 files changed, 73 insertions(+), 4 deletions(-)
diff --git a/datafusion/core/tests/sql/union.rs
b/datafusion/core/tests/sql/union.rs
index 4cf908aa8..5d16c1dc9 100644
--- a/datafusion/core/tests/sql/union.rs
+++ b/datafusion/core/tests/sql/union.rs
@@ -94,3 +94,49 @@ async fn union_with_type_coercion() -> Result<()> {
);
Ok(())
}
+
+#[tokio::test]
+async fn test_union_upcast_types() -> Result<()> {
+ let config = SessionConfig::new()
+ .with_repartition_windows(false)
+ .with_target_partitions(1);
+ let ctx = SessionContext::with_config(config);
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT c1, c9 FROM aggregate_test_100
+ UNION ALL
+ SELECT c1, c3 FROM aggregate_test_100
+ ORDER BY c9 DESC LIMIT 5";
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+
+ let expected_logical_plan = vec![
+ "Limit: skip=0, fetch=5 [c1:Utf8, c9:Int64]",
+ " Sort: c9 DESC NULLS FIRST [c1:Utf8, c9:Int64]",
+ " Union [c1:Utf8, c9:Int64]",
+ " Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c9
AS Int64) AS c9 [c1:Utf8, c9:Int64]",
+ " TableScan: aggregate_test_100 [c1:Utf8, c2:UInt32, c3:Int8,
c4:Int16, c5:Int32, c6:Int64, c7:UInt8, c8:UInt16, c9:UInt32, c10:UInt64,
c11:Float32, c12:Float64, c13:Utf8]",
+ " Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c3
AS Int64) AS c9 [c1:Utf8, c9:Int64]",
+ " TableScan: aggregate_test_100 [c1:Utf8, c2:UInt32, c3:Int8,
c4:Int16, c5:Int32, c6:Int64, c7:UInt8, c8:UInt16, c9:UInt32, c10:UInt64,
c11:Float32, c12:Float64, c13:Utf8]",
+ ];
+ let formatted_logical_plan =
+ dataframe.logical_plan().display_indent_schema().to_string();
+ let actual_logical_plan: Vec<&str> =
formatted_logical_plan.trim().lines().collect();
+ assert_eq!(expected_logical_plan, actual_logical_plan,
"\n\nexpected:\n\n{expected_logical_plan:#?}\nactual:\n\n{actual_logical_plan:#?}\n\n");
+
+ let actual = execute_to_batches(&ctx, sql).await;
+
+ let expected = vec![
+ "+----+------------+",
+ "| c1 | c9 |",
+ "+----+------------+",
+ "| c | 4268716378 |",
+ "| e | 4229654142 |",
+ "| d | 4216440507 |",
+ "| e | 4144173353 |",
+ "| b | 4076864659 |",
+ "+----+------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
diff --git a/datafusion/expr/src/type_coercion/binary.rs
b/datafusion/expr/src/type_coercion/binary.rs
index 3a2ca9c79..ba24f4a90 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -238,13 +238,36 @@ fn comparison_binary_numeric_coercion(
(_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type,
lhs_type),
(Float64, _) | (_, Float64) => Some(Float64),
(_, Float32) | (Float32, _) => Some(Float32),
- (Int64, _) | (_, Int64) => Some(Int64),
- (Int32, _) | (_, Int32) => Some(Int32),
- (Int16, _) | (_, Int16) => Some(Int16),
- (Int8, _) | (_, Int8) => Some(Int8),
+ // The match arms below encode this rule: given two integral types,
+ // we choose the narrowest integral type that can represent all values
+ // of both types. Note that some information loss is inevitable when
+ // we pair a signed type with `UInt64`; in that case we use `Int64`,
+ // i.e., the widest signed integral type.
+ (Int64, _)
+ | (_, Int64)
+ | (UInt64, Int8)
+ | (Int8, UInt64)
+ | (UInt64, Int16)
+ | (Int16, UInt64)
+ | (UInt64, Int32)
+ | (Int32, UInt64)
+ | (UInt32, Int8)
+ | (Int8, UInt32)
+ | (UInt32, Int16)
+ | (Int16, UInt32)
+ | (UInt32, Int32)
+ | (Int32, UInt32) => Some(Int64),
(UInt64, _) | (_, UInt64) => Some(UInt64),
+ (Int32, _)
+ | (_, Int32)
+ | (UInt16, Int16)
+ | (Int16, UInt16)
+ | (UInt16, Int8)
+ | (Int8, UInt16) => Some(Int32),
(UInt32, _) | (_, UInt32) => Some(UInt32),
+ (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
(UInt16, _) | (_, UInt16) => Some(UInt16),
+ (Int8, _) | (_, Int8) => Some(Int8),
(UInt8, _) | (_, UInt8) => Some(UInt8),
_ => None,
}