This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d076ab3b2 Bug/union wrong casting (#5342)
d076ab3b2 is described below

commit d076ab3b28ea0dffe65577dfa67dc7130c6006ce
Author: Berkay Şahin <[email protected]>
AuthorDate: Wed Mar 1 01:49:25 2023 +0300

    Bug/union wrong casting (#5342)
    
    * Fix union wrong casting
    
    * Test format modified
    
    * retry tests
    
    * Signed integers never lose sign info
    
    * Add comments explaining the match patterns
    
    ---------
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
 datafusion/core/tests/sql/union.rs          | 46 +++++++++++++++++++++++++++++
 datafusion/expr/src/type_coercion/binary.rs | 31 ++++++++++++++++---
 2 files changed, 73 insertions(+), 4 deletions(-)

diff --git a/datafusion/core/tests/sql/union.rs 
b/datafusion/core/tests/sql/union.rs
index 4cf908aa8..5d16c1dc9 100644
--- a/datafusion/core/tests/sql/union.rs
+++ b/datafusion/core/tests/sql/union.rs
@@ -94,3 +94,49 @@ async fn union_with_type_coercion() -> Result<()> {
     );
     Ok(())
 }
+
+#[tokio::test]
+async fn test_union_upcast_types() -> Result<()> {
+    let config = SessionConfig::new()
+        .with_repartition_windows(false)
+        .with_target_partitions(1);
+    let ctx = SessionContext::with_config(config);
+    register_aggregate_csv(&ctx).await?;
+    let sql = "SELECT c1, c9 FROM aggregate_test_100 
+                     UNION ALL 
+                     SELECT c1, c3 FROM aggregate_test_100 
+                     ORDER BY c9 DESC LIMIT 5";
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(sql).await.expect(&msg);
+
+    let expected_logical_plan = vec![
+        "Limit: skip=0, fetch=5 [c1:Utf8, c9:Int64]",
+        "  Sort: c9 DESC NULLS FIRST [c1:Utf8, c9:Int64]",
+        "    Union [c1:Utf8, c9:Int64]",
+        "      Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c9 
AS Int64) AS c9 [c1:Utf8, c9:Int64]",
+        "        TableScan: aggregate_test_100 [c1:Utf8, c2:UInt32, c3:Int8, 
c4:Int16, c5:Int32, c6:Int64, c7:UInt8, c8:UInt16, c9:UInt32, c10:UInt64, 
c11:Float32, c12:Float64, c13:Utf8]",
+        "      Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c3 
AS Int64) AS c9 [c1:Utf8, c9:Int64]",
+        "        TableScan: aggregate_test_100 [c1:Utf8, c2:UInt32, c3:Int8, 
c4:Int16, c5:Int32, c6:Int64, c7:UInt8, c8:UInt16, c9:UInt32, c10:UInt64, 
c11:Float32, c12:Float64, c13:Utf8]",
+    ];
+    let formatted_logical_plan =
+        dataframe.logical_plan().display_indent_schema().to_string();
+    let actual_logical_plan: Vec<&str> = 
formatted_logical_plan.trim().lines().collect();
+    assert_eq!(expected_logical_plan, actual_logical_plan, 
"\n\nexpected:\n\n{expected_logical_plan:#?}\nactual:\n\n{actual_logical_plan:#?}\n\n");
+
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    let expected = vec![
+        "+----+------------+",
+        "| c1 | c9         |",
+        "+----+------------+",
+        "| c  | 4268716378 |",
+        "| e  | 4229654142 |",
+        "| d  | 4216440507 |",
+        "| e  | 4144173353 |",
+        "| b  | 4076864659 |",
+        "+----+------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    Ok(())
+}
diff --git a/datafusion/expr/src/type_coercion/binary.rs 
b/datafusion/expr/src/type_coercion/binary.rs
index 3a2ca9c79..ba24f4a90 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -238,13 +238,36 @@ fn comparison_binary_numeric_coercion(
         (_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, 
lhs_type),
         (Float64, _) | (_, Float64) => Some(Float64),
         (_, Float32) | (Float32, _) => Some(Float32),
-        (Int64, _) | (_, Int64) => Some(Int64),
-        (Int32, _) | (_, Int32) => Some(Int32),
-        (Int16, _) | (_, Int16) => Some(Int16),
-        (Int8, _) | (_, Int8) => Some(Int8),
+        // The match arms below encode the following logic: Given the two
+        // integral types, we choose the narrowest possible integral type that
+        // accommodates all values of both types. Note that some information
+        // loss is inevitable when we have a signed type and a `UInt64`, in
+        // which case we use `Int64`; i.e. the widest signed integral type.
+        (Int64, _)
+        | (_, Int64)
+        | (UInt64, Int8)
+        | (Int8, UInt64)
+        | (UInt64, Int16)
+        | (Int16, UInt64)
+        | (UInt64, Int32)
+        | (Int32, UInt64)
+        | (UInt32, Int8)
+        | (Int8, UInt32)
+        | (UInt32, Int16)
+        | (Int16, UInt32)
+        | (UInt32, Int32)
+        | (Int32, UInt32) => Some(Int64),
         (UInt64, _) | (_, UInt64) => Some(UInt64),
+        (Int32, _)
+        | (_, Int32)
+        | (UInt16, Int16)
+        | (Int16, UInt16)
+        | (UInt16, Int8)
+        | (Int8, UInt16) => Some(Int32),
         (UInt32, _) | (_, UInt32) => Some(UInt32),
+        (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
         (UInt16, _) | (_, UInt16) => Some(UInt16),
+        (Int8, _) | (_, Int8) => Some(Int8),
         (UInt8, _) | (_, UInt8) => Some(UInt8),
         _ => None,
     }

Reply via email to