This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new dc7e44db48 fix: Unconditionally wrap UNION BY NAME input nodes w/ `Projection` (#15242)
dc7e44db48 is described below

commit dc7e44db48a2c1a8203ba9725c877a7fd707cc0f
Author: Rohan Krishnaswamy <47869999+rkris...@users.noreply.github.com>
AuthorDate: Thu Mar 27 12:06:10 2025 -0700

    fix: Unconditionally wrap UNION BY NAME input nodes w/ `Projection` (#15242)
    
    * fix: Remove incorrect predicate to skip input wrapping when rewriting union inputs
    
    * chore: Add/update tests
    
    * fix: SQL integration tests
    
    * test: Add union all by name SLT tests
    
    * test: Add problematic union all by name SLT test
    
    * chore: styling nits
    
    * fix: Correct handling of nullability when field is not present in all inputs
    
    * chore: Update fixme comment
    
    * fix: handle ordering by order of inputs
---
 datafusion/expr/src/logical_plan/plan.rs           |  89 +++++-----
 datafusion/sql/tests/sql_integration.rs            |  27 +--
 .../sqllogictest/test_files/union_by_name.slt      | 196 +++++++++++++++++----
 3 files changed, 229 insertions(+), 83 deletions(-)

diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 641489b5d9..76b45d5d72 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -18,7 +18,7 @@
 //! Logical plan types
 
 use std::cmp::Ordering;
-use std::collections::{BTreeMap, HashMap, HashSet};
+use std::collections::{HashMap, HashSet};
 use std::fmt::{self, Debug, Display, Formatter};
 use std::hash::{Hash, Hasher};
 use std::str::FromStr;
@@ -2681,24 +2681,16 @@ impl Union {
         Ok(Union { inputs, schema })
     }
 
-    /// When constructing a `UNION BY NAME`, we may need to wrap inputs
+    /// When constructing a `UNION BY NAME`, we need to wrap inputs
     /// in an additional `Projection` to account for absence of columns
-    /// in input schemas.
+    /// in input schemas or differing projection orders.
     fn rewrite_inputs_from_schema(
-        schema: &DFSchema,
+        schema: &Arc<DFSchema>,
         inputs: Vec<Arc<LogicalPlan>>,
     ) -> Result<Vec<Arc<LogicalPlan>>> {
         let schema_width = schema.iter().count();
         let mut wrapped_inputs = Vec::with_capacity(inputs.len());
         for input in inputs {
-            // If the input plan's schema contains the same number of fields
-            // as the derived schema, then it does not to be wrapped in an
-            // additional `Projection`.
-            if input.schema().iter().count() == schema_width {
-                wrapped_inputs.push(input);
-                continue;
-            }
-
             // Any columns that exist within the derived schema but do not 
exist
             // within an input's schema should be replaced with `NULL` aliased
             // to the appropriate column in the derived schema.
@@ -2713,9 +2705,9 @@ impl Union {
                     
expr.push(Expr::Literal(ScalarValue::Null).alias(column.name()));
                 }
             }
-            
wrapped_inputs.push(Arc::new(LogicalPlan::Projection(Projection::try_new(
-                expr, input,
-            )?)));
+            wrapped_inputs.push(Arc::new(LogicalPlan::Projection(
+                Projection::try_new_with_schema(expr, input, 
Arc::clone(schema))?,
+            )));
         }
 
         Ok(wrapped_inputs)
@@ -2749,45 +2741,60 @@ impl Union {
         inputs: &[Arc<LogicalPlan>],
         loose_types: bool,
     ) -> Result<DFSchemaRef> {
-        type FieldData<'a> = (&'a DataType, bool, Vec<&'a HashMap<String, 
String>>);
-        // Prefer `BTreeMap` as it produces items in order by key when 
iterated over
-        let mut cols: BTreeMap<&str, FieldData> = BTreeMap::new();
+        type FieldData<'a> =
+            (&'a DataType, bool, Vec<&'a HashMap<String, String>>, usize);
+        let mut cols: Vec<(&str, FieldData)> = Vec::new();
         for input in inputs.iter() {
             for field in input.schema().fields() {
-                match cols.entry(field.name()) {
-                    std::collections::btree_map::Entry::Occupied(mut occupied) 
=> {
-                        let (data_type, is_nullable, metadata) = 
occupied.get_mut();
-                        if !loose_types && *data_type != field.data_type() {
-                            return plan_err!(
-                                "Found different types for field {}",
-                                field.name()
-                            );
-                        }
-
-                        metadata.push(field.metadata());
-                        // If the field is nullable in any one of the inputs,
-                        // then the field in the final schema is also nullable.
-                        *is_nullable |= field.is_nullable();
+                if let Some((_, (data_type, is_nullable, metadata, 
occurrences))) =
+                    cols.iter_mut().find(|(name, _)| name == field.name())
+                {
+                    if !loose_types && *data_type != field.data_type() {
+                        return plan_err!(
+                            "Found different types for field {}",
+                            field.name()
+                        );
                     }
-                    std::collections::btree_map::Entry::Vacant(vacant) => {
-                        vacant.insert((
+
+                    metadata.push(field.metadata());
+                    // If the field is nullable in any one of the inputs,
+                    // then the field in the final schema is also nullable.
+                    *is_nullable |= field.is_nullable();
+                    *occurrences += 1;
+                } else {
+                    cols.push((
+                        field.name(),
+                        (
                             field.data_type(),
                             field.is_nullable(),
                             vec![field.metadata()],
-                        ));
-                    }
+                            1,
+                        ),
+                    ));
                 }
             }
         }
 
         let union_fields = cols
             .into_iter()
-            .map(|(name, (data_type, is_nullable, unmerged_metadata))| {
-                let mut field = Field::new(name, data_type.clone(), 
is_nullable);
-                field.set_metadata(intersect_maps(unmerged_metadata));
+            .map(
+                |(name, (data_type, is_nullable, unmerged_metadata, 
occurrences))| {
+                    // If the final number of occurrences of the field is less
+                    // than the number of inputs (i.e. the field is missing 
from
+                    // one or more inputs), then it must be treated as 
nullable.
+                    let final_is_nullable = if occurrences == inputs.len() {
+                        is_nullable
+                    } else {
+                        true
+                    };
 
-                (None, Arc::new(field))
-            })
+                    let mut field =
+                        Field::new(name, data_type.clone(), final_is_nullable);
+                    field.set_metadata(intersect_maps(unmerged_metadata));
+
+                    (None, Arc::new(field))
+                },
+            )
             .collect::<Vec<(Option<TableReference>, _)>>();
 
         let union_schema_metadata =
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index 2939e965cd..866c08ed02 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -1898,11 +1898,12 @@ fn union_by_name_different_columns() {
     let expected = "\
         Distinct:\
         \n  Union\
-        \n    Projection: NULL AS Int64(1), order_id\
+        \n    Projection: order_id, NULL AS Int64(1)\
         \n      Projection: orders.order_id\
         \n        TableScan: orders\
-        \n    Projection: orders.order_id, Int64(1)\
-        \n      TableScan: orders";
+        \n    Projection: order_id, Int64(1)\
+        \n      Projection: orders.order_id, Int64(1)\
+        \n        TableScan: orders";
     quick_test(sql, expected);
 }
 
@@ -1936,22 +1937,26 @@ fn union_all_by_name_different_columns() {
         "SELECT order_id from orders UNION ALL BY NAME SELECT order_id, 1 FROM 
orders";
     let expected = "\
         Union\
-        \n  Projection: NULL AS Int64(1), order_id\
+        \n  Projection: order_id, NULL AS Int64(1)\
         \n    Projection: orders.order_id\
         \n      TableScan: orders\
-        \n  Projection: orders.order_id, Int64(1)\
-        \n    TableScan: orders";
+        \n  Projection: order_id, Int64(1)\
+        \n    Projection: orders.order_id, Int64(1)\
+        \n      TableScan: orders";
     quick_test(sql, expected);
 }
 
 #[test]
 fn union_all_by_name_same_column_names() {
     let sql = "SELECT order_id from orders UNION ALL BY NAME SELECT order_id 
FROM orders";
-    let expected = "Union\
-            \n  Projection: orders.order_id\
-            \n    TableScan: orders\
-            \n  Projection: orders.order_id\
-            \n    TableScan: orders";
+    let expected = "\
+        Union\
+        \n  Projection: order_id\
+        \n    Projection: orders.order_id\
+        \n      TableScan: orders\
+        \n  Projection: order_id\
+        \n    Projection: orders.order_id\
+        \n      TableScan: orders";
     quick_test(sql, expected);
 }
 
diff --git a/datafusion/sqllogictest/test_files/union_by_name.slt b/datafusion/sqllogictest/test_files/union_by_name.slt
index 3844dba680..9572e6efc3 100644
--- a/datafusion/sqllogictest/test_files/union_by_name.slt
+++ b/datafusion/sqllogictest/test_files/union_by_name.slt
@@ -88,38 +88,38 @@ SELECT x FROM t1 UNION ALL BY NAME SELECT x FROM t1 ORDER BY x;
 query II
 (SELECT x FROM t1 UNION ALL SELECT x FROM t1) UNION BY NAME SELECT 5 ORDER BY 
x;
 ----
-NULL 1
-NULL 3
-5 NULL
+1 NULL
+3 NULL
+NULL 5
 
 query II
 (SELECT x FROM t1 UNION ALL SELECT x FROM t1) UNION ALL BY NAME SELECT 5 ORDER 
BY x;
 ----
-NULL 1
-NULL 1
-NULL 3
-NULL 3
-NULL 3
-NULL 3
-5 NULL
+1 NULL
+1 NULL
+3 NULL
+3 NULL
+3 NULL
+3 NULL
+NULL 5
 
 query II
 (SELECT x FROM t1 UNION ALL SELECT y FROM t1) UNION BY NAME SELECT 5 ORDER BY 
x;
 ----
-NULL 1
-NULL 3
-5 NULL
+1 NULL
+3 NULL
+NULL 5
 
 query II
 (SELECT x FROM t1 UNION ALL SELECT y FROM t1) UNION ALL BY NAME SELECT 5 ORDER 
BY x;
 ----
-NULL 1
-NULL 1
-NULL 3
-NULL 3
-NULL 3
-NULL 3
-5 NULL
+1 NULL
+1 NULL
+3 NULL
+3 NULL
+3 NULL
+3 NULL
+NULL 5
 
 
 # Ambiguous name
@@ -152,22 +152,22 @@ NULL 4
 # Limit
 
 query III
-SELECT 1 UNION BY NAME SELECT * FROM unnest(range(2, 100)) UNION BY NAME 
SELECT 999 ORDER BY 3, 1 LIMIT 5;
+SELECT 1 UNION BY NAME SELECT * FROM unnest(range(2, 100)) UNION BY NAME 
SELECT 999 ORDER BY 3, 1, 2 LIMIT 5;
 ----
-NULL NULL 2
-NULL NULL 3
-NULL NULL 4
-NULL NULL 5
-NULL NULL 6
+NULL NULL 999
+1 NULL NULL
+NULL 2 NULL
+NULL 3 NULL
+NULL 4 NULL
 
 query III
 SELECT 1 UNION ALL BY NAME SELECT * FROM unnest(range(2, 100)) UNION ALL BY 
NAME SELECT 999 ORDER BY 3, 1 LIMIT 5;
 ----
-NULL NULL 2
-NULL NULL 3
-NULL NULL 4
-NULL NULL 5
-NULL NULL 6
+NULL NULL 999
+1 NULL NULL
+NULL 2 NULL
+NULL 3 NULL
+NULL 4 NULL
 
 # Order by
 
@@ -287,3 +287,137 @@ SELECT '0' as c UNION ALL BY NAME SELECT 0 as c;
 ----
 0
 0
+
+# Regression tests for https://github.com/apache/datafusion/issues/15236
+# Ensure that the correct output is produced even if the width of an input 
node's
+# schema is the same as the resulting schema width after the union is applied.
+
+statement ok
+create table t3 (x varchar(255), y varchar(255), z varchar(255));
+
+statement ok
+create table t4 (x varchar(255), y varchar(255), z varchar(255));
+
+statement ok
+insert into t3 values ('a', 'b', 'c');
+
+statement ok
+insert into t4 values ('a', 'b', 'c');
+
+query TTTT rowsort
+select t3.x, t3.y, t3.z from t3 union by name select t3.z, t3.y, t3.x, 'd' as 
zz from t3;
+----
+a b c NULL
+a b c d
+
+query TTTT rowsort
+select t3.x, t3.y, t3.z from t3 union by name select t4.z, t4.y, t4.x, 'd' as 
zz from t4;
+----
+a b c NULL
+a b c d
+
+query TTT rowsort
+select x, y, z from t3 union all by name select z, y, x from t3;
+----
+a b c
+a b c
+
+query TTT rowsort
+select x, y, z from t3 union all by name select z, y, x from t4;
+----
+a b c
+a b c
+
+query TTT
+select x, y, z from t3 union all by name select z, y, x from t4 order by x;
+----
+a b c
+a b c
+
+
+# FIXME: The following should pass without error, but currently it is failing
+# due to differing record batch schemas when the SLT runner collects results.
+# This is due to the following issue: 
https://github.com/apache/datafusion/issues/15394#issue-2943811768
+#
+# More context can be found here: 
https://github.com/apache/datafusion/pull/15242#issuecomment-2746563234
+query error
+select x, y, z from t3 union all by name select z, y, x, 'd' as zz from t3;
+----
+DataFusion error: Internal error: Schema mismatch. Previously had
+Schema {
+    fields: [
+        Field {
+            name: "x",
+            data_type: Utf8,
+            nullable: true,
+            dict_id: 0,
+            dict_is_ordered: false,
+            metadata: {},
+        },
+        Field {
+            name: "y",
+            data_type: Utf8,
+            nullable: true,
+            dict_id: 0,
+            dict_is_ordered: false,
+            metadata: {},
+        },
+        Field {
+            name: "z",
+            data_type: Utf8,
+            nullable: true,
+            dict_id: 0,
+            dict_is_ordered: false,
+            metadata: {},
+        },
+        Field {
+            name: "zz",
+            data_type: Utf8,
+            nullable: false,
+            dict_id: 0,
+            dict_is_ordered: false,
+            metadata: {},
+        },
+    ],
+    metadata: {},
+}
+
+Got:
+Schema {
+    fields: [
+        Field {
+            name: "x",
+            data_type: Utf8,
+            nullable: true,
+            dict_id: 0,
+            dict_is_ordered: false,
+            metadata: {},
+        },
+        Field {
+            name: "y",
+            data_type: Utf8,
+            nullable: true,
+            dict_id: 0,
+            dict_is_ordered: false,
+            metadata: {},
+        },
+        Field {
+            name: "z",
+            data_type: Utf8,
+            nullable: true,
+            dict_id: 0,
+            dict_is_ordered: false,
+            metadata: {},
+        },
+        Field {
+            name: "zz",
+            data_type: Utf8,
+            nullable: true,
+            dict_id: 0,
+            dict_is_ordered: false,
+            metadata: {},
+        },
+    ],
+    metadata: {},
+}.
+This was likely caused by a bug in DataFusion's code and we would welcome that 
you file an bug report in our issue tracker


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to