This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 3b6aac2fce Support struct coercion in `type_union_resolution` (#12839)
3b6aac2fce is described below

commit 3b6aac2fcecdb003427f9475f061ed2cc52e8558
Author: Jay Zhan <[email protected]>
AuthorDate: Fri Oct 11 13:13:36 2024 +0800

    Support struct coercion in `type_union_resolution` (#12839)
    
    * support struct
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix struct
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rm todo
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add more test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix field order
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add list of struct test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * upd err msg
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion/expr-common/src/type_coercion/binary.rs |  46 ++++++++-
 datafusion/expr/src/type_coercion/functions.rs     |  39 +++++---
 datafusion/functions-nested/src/make_array.rs      |  48 +++++++++
 datafusion/sqllogictest/test_files/struct.slt      | 109 ++++++++++++++++++++-
 4 files changed, 228 insertions(+), 14 deletions(-)

diff --git a/datafusion/expr-common/src/type_coercion/binary.rs 
b/datafusion/expr-common/src/type_coercion/binary.rs
index 6d66b8b4df..e042dd5d3a 100644
--- a/datafusion/expr-common/src/type_coercion/binary.rs
+++ b/datafusion/expr-common/src/type_coercion/binary.rs
@@ -25,8 +25,8 @@ use crate::operator::Operator;
 use arrow::array::{new_empty_array, Array};
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{
-    DataType, Field, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, 
DECIMAL128_MAX_SCALE,
-    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
+    DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION,
+    DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
 };
 use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, 
Result};
 
@@ -370,6 +370,8 @@ impl From<&DataType> for TypeCategory {
 /// align with the behavior of Postgres. Therefore, we've made slight 
adjustments to the rules
 /// to better match the behavior of both Postgres and DuckDB. For example, we 
expect adjusted
 /// decimal precision and scale when coercing decimal types.
+///
+/// This function doesn't preserve the correct field names and nullability for 
struct types; it only resolves the data types.
 pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> {
     if data_types.is_empty() {
         return None;
@@ -476,6 +478,46 @@ fn type_union_resolution_coercion(
                 type_union_resolution_coercion(lhs.data_type(), 
rhs.data_type());
             new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", 
t, true))))
         }
+        (DataType::Struct(lhs), DataType::Struct(rhs)) => {
+            if lhs.len() != rhs.len() {
+                return None;
+            }
+
+            // Search the field in the right hand side with the SAME field name
+            fn search_corresponding_coerced_type(
+                lhs_field: &FieldRef,
+                rhs: &Fields,
+            ) -> Option<DataType> {
+                for rhs_field in rhs.iter() {
+                    if lhs_field.name() == rhs_field.name() {
+                        if let Some(t) = type_union_resolution_coercion(
+                            lhs_field.data_type(),
+                            rhs_field.data_type(),
+                        ) {
+                            return Some(t);
+                        } else {
+                            return None;
+                        }
+                    }
+                }
+
+                None
+            }
+
+            let types = lhs
+                .iter()
+                .map(|lhs_field| search_corresponding_coerced_type(lhs_field, 
rhs))
+                .collect::<Option<Vec<_>>>()?;
+
+            let fields = types
+                .into_iter()
+                .enumerate()
+                .map(|(i, datatype)| {
+                    Arc::new(Field::new(format!("c{i}"), datatype, true))
+                })
+                .collect::<Vec<FieldRef>>();
+            Some(DataType::Struct(fields.into()))
+        }
         _ => {
             // numeric coercion is the same as comparison coercion, both find 
the narrowest type
             // that can accommodate both types
diff --git a/datafusion/expr/src/type_coercion/functions.rs 
b/datafusion/expr/src/type_coercion/functions.rs
index 143e00fa40..85f8e20ba4 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -221,20 +221,37 @@ fn get_valid_types_with_scalar_udf(
     current_types: &[DataType],
     func: &ScalarUDF,
 ) -> Result<Vec<Vec<DataType>>> {
-    let valid_types = match signature {
+    match signature {
         TypeSignature::UserDefined => match func.coerce_types(current_types) {
-            Ok(coerced_types) => vec![coerced_types],
-            Err(e) => return exec_err!("User-defined coercion failed with 
{:?}", e),
+            Ok(coerced_types) => Ok(vec![coerced_types]),
+            Err(e) => exec_err!("User-defined coercion failed with {:?}", e),
         },
-        TypeSignature::OneOf(signatures) => signatures
-            .iter()
-            .filter_map(|t| get_valid_types_with_scalar_udf(t, current_types, 
func).ok())
-            .flatten()
-            .collect::<Vec<_>>(),
-        _ => get_valid_types(signature, current_types)?,
-    };
+        TypeSignature::OneOf(signatures) => {
+            let mut res = vec![];
+            let mut errors = vec![];
+            for sig in signatures {
+                match get_valid_types_with_scalar_udf(sig, current_types, 
func) {
+                    Ok(valid_types) => {
+                        res.extend(valid_types);
+                    }
+                    Err(e) => {
+                        errors.push(e.to_string());
+                    }
+                }
+            }
 
-    Ok(valid_types)
+            // Every signature failed, return the joined error
+            if res.is_empty() {
+                internal_err!(
+                    "Failed to match any signature, errors: {}",
+                    errors.join(",")
+                )
+            } else {
+                Ok(res)
+            }
+        }
+        _ => get_valid_types(signature, current_types),
+    }
 }
 
 fn get_valid_types_with_aggregate_udf(
diff --git a/datafusion/functions-nested/src/make_array.rs 
b/datafusion/functions-nested/src/make_array.rs
index 51fc71e6b0..cafa073f91 100644
--- a/datafusion/functions-nested/src/make_array.rs
+++ b/datafusion/functions-nested/src/make_array.rs
@@ -27,10 +27,12 @@ use arrow_array::{
 use arrow_buffer::OffsetBuffer;
 use arrow_schema::DataType::{LargeList, List, Null};
 use arrow_schema::{DataType, Field};
+use datafusion_common::{exec_err, internal_err};
 use datafusion_common::{plan_err, utils::array_into_list_array_nullable, 
Result};
 use datafusion_expr::binary::type_union_resolution;
 use datafusion_expr::TypeSignature;
 use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use itertools::Itertools;
 
 use crate::utils::make_scalar_function;
 
@@ -106,6 +108,32 @@ impl ScalarUDFImpl for MakeArray {
 
     fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
         if let Some(new_type) = type_union_resolution(arg_types) {
+            // TODO: Move the logic to type_union_resolution if this applies 
to other functions as well
+            // Handle struct where we only change the data type but preserve 
the field name and nullability.
+            // Since field name is the key of the struct, so it shouldn't be 
updated to the common column name like "c0" or "c1"
+            let is_struct_and_has_same_key = 
are_all_struct_and_have_same_key(arg_types)?;
+            if is_struct_and_has_same_key {
+                let data_types: Vec<_> = if let DataType::Struct(fields) = 
&arg_types[0] {
+                    fields.iter().map(|f| f.data_type().to_owned()).collect()
+                } else {
+                    return internal_err!("Struct type is checked is the 
previous function, so this should be unreachable");
+                };
+
+                let mut final_struct_types = vec![];
+                for s in arg_types {
+                    let mut new_fields = vec![];
+                    if let DataType::Struct(fields) = s {
+                        for (i, f) in fields.iter().enumerate() {
+                            let field = Arc::unwrap_or_clone(Arc::clone(f))
+                                .with_data_type(data_types[i].to_owned());
+                            new_fields.push(Arc::new(field));
+                        }
+                    }
+                    
final_struct_types.push(DataType::Struct(new_fields.into()))
+                }
+                return Ok(final_struct_types);
+            }
+
             if let DataType::FixedSizeList(field, _) = new_type {
                 Ok(vec![DataType::List(field); arg_types.len()])
             } else if new_type.is_null() {
@@ -123,6 +151,26 @@ impl ScalarUDFImpl for MakeArray {
     }
 }
 
+fn are_all_struct_and_have_same_key(data_types: &[DataType]) -> Result<bool> {
+    let mut keys_string: Option<String> = None;
+    for data_type in data_types {
+        if let DataType::Struct(fields) = data_type {
+            let keys = fields.iter().map(|f| f.name().to_owned()).join(",");
+            if let Some(ref k) = keys_string {
+                if *k != keys {
+                    return exec_err!("Expect same keys for struct type but got 
mismatched pair {} and {}", *k, keys);
+                }
+            } else {
+                keys_string = Some(keys);
+            }
+        } else {
+            return Ok(false);
+        }
+    }
+
+    Ok(true)
+}
+
 // Empty array is a special case that is useful for many other array functions
 pub(super) fn empty_array_type() -> DataType {
     DataType::List(Arc::new(Field::new("item", DataType::Int64, true)))
diff --git a/datafusion/sqllogictest/test_files/struct.slt 
b/datafusion/sqllogictest/test_files/struct.slt
index 67cd7d71fc..b76c78396a 100644
--- a/datafusion/sqllogictest/test_files/struct.slt
+++ b/datafusion/sqllogictest/test_files/struct.slt
@@ -374,6 +374,34 @@ You reached the bottom!
 statement ok
 drop view complex_view;
 
+# struct with different keys r1 and r2 is not valid
+statement ok
+create table t(a struct<r1 varchar, c int>, b struct<r2 varchar, c float>) as 
values (struct('red', 1), struct('blue', 2.3));
+
+# Expect same keys for struct type but got mismatched pair r1,c and r2,c
+query error
+select [a, b] from t;
+
+statement ok
+drop table t;
+
+# struct with the same key
+statement ok
+create table t(a struct<r varchar, c int>, b struct<r varchar, c float>) as 
values (struct('red', 1), struct('blue', 2.3));
+
+query T
+select arrow_typeof([a, b]) from t;
+----
+List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: 
Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field 
{ name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: 
false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, 
metadata: {} })
+
+query ?
+select [a, b] from t;
+----
+[{r: red, c: 1}, {r: blue, c: 2}]
+
+statement ok
+drop table t;
+
 # Test row alias
 
 query ?
@@ -412,7 +440,6 @@ select * from t;
 ----
 {r: red, b: 2} {r: blue, b: 2.3}
 
-# TODO: Should be coerced to float
 query T
 select arrow_typeof(c1) from t;
 ----
@@ -422,3 +449,83 @@ query T
 select arrow_typeof(c2) from t;
 ----
 Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Float32, 
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
+
+statement ok
+drop table t;
+
+##################################
+## Test Coalesce with Struct
+##################################
+
+statement ok
+CREATE TABLE t (
+    s1 struct(a int, b varchar),
+    s2 struct(a float, b varchar)
+) AS VALUES
+  (row(1, 'red'), row(1.1, 'string1')),
+  (row(2, 'blue'), row(2.2, 'string2')),
+  (row(3, 'green'), row(33.2, 'string3'))
+;
+
+query ?
+select coalesce(s1) from t;
+----
+{a: 1, b: red}
+{a: 2, b: blue}
+{a: 3, b: green}
+
+# TODO: a's type should be float
+query T
+select arrow_typeof(coalesce(s1)) from t;
+----
+Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, 
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
+Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, 
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
+Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, 
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
+
+statement ok
+drop table t;
+
+statement ok
+CREATE TABLE t (
+    s1 struct(a int, b varchar),
+    s2 struct(a float, b varchar)
+) AS VALUES
+  (row(1, 'red'), row(1.1, 'string1')),
+  (null, row(2.2, 'string2')),
+  (row(3, 'green'), row(33.2, 'string3'))
+;
+
+# TODO: second column should not be null
+query ?
+select coalesce(s1) from t;
+----
+{a: 1, b: red}
+NULL
+{a: 3, b: green}
+
+statement ok
+drop table t;
+
+# row() with incorrect order
+statement error DataFusion error: Arrow error: Cast error: Cannot cast string 
'blue' to value of Float64 type
+create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as 
values 
+    (row('red', 1), row(2.3, 'blue')),
+    (row('purple', 1), row('green', 2.3));
+
+# out of order struct literal
+# TODO: This query should not fail
+statement error DataFusion error: Arrow error: Cast error: Cannot cast string 
'a' to value of Int64 type
+create table t(a struct(r varchar, c int)) as values ({r: 'a', c: 1}), ({c: 2, 
r: 'b'});
+
+##################################
+## Test Array of Struct
+##################################
+
+query ?
+select [{r: 'a', c: 1}, {r: 'b', c: 2}];
+----
+[{r: a, c: 1}, {r: b, c: 2}]
+
+# Can't create a list of structs whose fields are in a different order
+query error
+select [{r: 'a', c: 1}, {c: 2, r: 'b'}];


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to