This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 3b6aac2fce Support struct coercion in `type_union_resolution` (#12839)
3b6aac2fce is described below
commit 3b6aac2fcecdb003427f9475f061ed2cc52e8558
Author: Jay Zhan <[email protected]>
AuthorDate: Fri Oct 11 13:13:36 2024 +0800
Support struct coercion in `type_union_resolution` (#12839)
* support struct
Signed-off-by: jayzhan211 <[email protected]>
* fix struct
Signed-off-by: jayzhan211 <[email protected]>
* rm todo
Signed-off-by: jayzhan211 <[email protected]>
* add more test
Signed-off-by: jayzhan211 <[email protected]>
* fix field order
Signed-off-by: jayzhan211 <[email protected]>
* add list of struct test
Signed-off-by: jayzhan211 <[email protected]>
* upd err msg
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/expr-common/src/type_coercion/binary.rs | 46 ++++++++-
datafusion/expr/src/type_coercion/functions.rs | 39 +++++---
datafusion/functions-nested/src/make_array.rs | 48 +++++++++
datafusion/sqllogictest/test_files/struct.slt | 109 ++++++++++++++++++++-
4 files changed, 228 insertions(+), 14 deletions(-)
diff --git a/datafusion/expr-common/src/type_coercion/binary.rs
b/datafusion/expr-common/src/type_coercion/binary.rs
index 6d66b8b4df..e042dd5d3a 100644
--- a/datafusion/expr-common/src/type_coercion/binary.rs
+++ b/datafusion/expr-common/src/type_coercion/binary.rs
@@ -25,8 +25,8 @@ use crate::operator::Operator;
use arrow::array::{new_empty_array, Array};
use arrow::compute::can_cast_types;
use arrow::datatypes::{
- DataType, Field, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION,
DECIMAL128_MAX_SCALE,
- DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
+ DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION,
+ DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};
use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err,
Result};
@@ -370,6 +370,8 @@ impl From<&DataType> for TypeCategory {
/// align with the behavior of Postgres. Therefore, we've made slight
adjustments to the rules
/// to better match the behavior of both Postgres and DuckDB. For example, we
expect adjusted
/// decimal precision and scale when coercing decimal types.
+///
+/// This function doesn't preserve correct field name and nullability for the
struct type, we only care about data type.
pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> {
if data_types.is_empty() {
return None;
@@ -476,6 +478,46 @@ fn type_union_resolution_coercion(
type_union_resolution_coercion(lhs.data_type(),
rhs.data_type());
new_item_type.map(|t| DataType::List(Arc::new(Field::new("item",
t, true))))
}
+ (DataType::Struct(lhs), DataType::Struct(rhs)) => {
+ if lhs.len() != rhs.len() {
+ return None;
+ }
+
+ // Search the field in the right hand side with the SAME field name
+ fn search_corresponding_coerced_type(
+ lhs_field: &FieldRef,
+ rhs: &Fields,
+ ) -> Option<DataType> {
+ for rhs_field in rhs.iter() {
+ if lhs_field.name() == rhs_field.name() {
+ if let Some(t) = type_union_resolution_coercion(
+ lhs_field.data_type(),
+ rhs_field.data_type(),
+ ) {
+ return Some(t);
+ } else {
+ return None;
+ }
+ }
+ }
+
+ None
+ }
+
+ let types = lhs
+ .iter()
+ .map(|lhs_field| search_corresponding_coerced_type(lhs_field,
rhs))
+ .collect::<Option<Vec<_>>>()?;
+
+ let fields = types
+ .into_iter()
+ .enumerate()
+ .map(|(i, datatype)| {
+ Arc::new(Field::new(format!("c{i}"), datatype, true))
+ })
+ .collect::<Vec<FieldRef>>();
+ Some(DataType::Struct(fields.into()))
+ }
_ => {
// numeric coercion is the same as comparison coercion, both find
the narrowest type
// that can accommodate both types
diff --git a/datafusion/expr/src/type_coercion/functions.rs
b/datafusion/expr/src/type_coercion/functions.rs
index 143e00fa40..85f8e20ba4 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -221,20 +221,37 @@ fn get_valid_types_with_scalar_udf(
current_types: &[DataType],
func: &ScalarUDF,
) -> Result<Vec<Vec<DataType>>> {
- let valid_types = match signature {
+ match signature {
TypeSignature::UserDefined => match func.coerce_types(current_types) {
- Ok(coerced_types) => vec![coerced_types],
- Err(e) => return exec_err!("User-defined coercion failed with
{:?}", e),
+ Ok(coerced_types) => Ok(vec![coerced_types]),
+ Err(e) => exec_err!("User-defined coercion failed with {:?}", e),
},
- TypeSignature::OneOf(signatures) => signatures
- .iter()
- .filter_map(|t| get_valid_types_with_scalar_udf(t, current_types,
func).ok())
- .flatten()
- .collect::<Vec<_>>(),
- _ => get_valid_types(signature, current_types)?,
- };
+ TypeSignature::OneOf(signatures) => {
+ let mut res = vec![];
+ let mut errors = vec![];
+ for sig in signatures {
+ match get_valid_types_with_scalar_udf(sig, current_types,
func) {
+ Ok(valid_types) => {
+ res.extend(valid_types);
+ }
+ Err(e) => {
+ errors.push(e.to_string());
+ }
+ }
+ }
- Ok(valid_types)
+ // Every signature failed, return the joined error
+ if res.is_empty() {
+ internal_err!(
+ "Failed to match any signature, errors: {}",
+ errors.join(",")
+ )
+ } else {
+ Ok(res)
+ }
+ }
+ _ => get_valid_types(signature, current_types),
+ }
}
fn get_valid_types_with_aggregate_udf(
diff --git a/datafusion/functions-nested/src/make_array.rs
b/datafusion/functions-nested/src/make_array.rs
index 51fc71e6b0..cafa073f91 100644
--- a/datafusion/functions-nested/src/make_array.rs
+++ b/datafusion/functions-nested/src/make_array.rs
@@ -27,10 +27,12 @@ use arrow_array::{
use arrow_buffer::OffsetBuffer;
use arrow_schema::DataType::{LargeList, List, Null};
use arrow_schema::{DataType, Field};
+use datafusion_common::{exec_err, internal_err};
use datafusion_common::{plan_err, utils::array_into_list_array_nullable,
Result};
use datafusion_expr::binary::type_union_resolution;
use datafusion_expr::TypeSignature;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+use itertools::Itertools;
use crate::utils::make_scalar_function;
@@ -106,6 +108,32 @@ impl ScalarUDFImpl for MakeArray {
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if let Some(new_type) = type_union_resolution(arg_types) {
+ // TODO: Move the logic to type_union_resolution if this applies
to other functions as well
+ // Handle struct where we only change the data type but preserve
the field name and nullability.
+ // Since field name is the key of the struct, so it shouldn't be
updated to the common column name like "c0" or "c1"
+ let is_struct_and_has_same_key =
are_all_struct_and_have_same_key(arg_types)?;
+ if is_struct_and_has_same_key {
+ let data_types: Vec<_> = if let DataType::Struct(fields) =
&arg_types[0] {
+ fields.iter().map(|f| f.data_type().to_owned()).collect()
+ } else {
+ return internal_err!("Struct type is checked is the
previous function, so this should be unreachable");
+ };
+
+ let mut final_struct_types = vec![];
+ for s in arg_types {
+ let mut new_fields = vec![];
+ if let DataType::Struct(fields) = s {
+ for (i, f) in fields.iter().enumerate() {
+ let field = Arc::unwrap_or_clone(Arc::clone(f))
+ .with_data_type(data_types[i].to_owned());
+ new_fields.push(Arc::new(field));
+ }
+ }
+
final_struct_types.push(DataType::Struct(new_fields.into()))
+ }
+ return Ok(final_struct_types);
+ }
+
if let DataType::FixedSizeList(field, _) = new_type {
Ok(vec![DataType::List(field); arg_types.len()])
} else if new_type.is_null() {
@@ -123,6 +151,26 @@ impl ScalarUDFImpl for MakeArray {
}
}
+fn are_all_struct_and_have_same_key(data_types: &[DataType]) -> Result<bool> {
+ let mut keys_string: Option<String> = None;
+ for data_type in data_types {
+ if let DataType::Struct(fields) = data_type {
+ let keys = fields.iter().map(|f| f.name().to_owned()).join(",");
+ if let Some(ref k) = keys_string {
+ if *k != keys {
+ return exec_err!("Expect same keys for struct type but got
mismatched pair {} and {}", *k, keys);
+ }
+ } else {
+ keys_string = Some(keys);
+ }
+ } else {
+ return Ok(false);
+ }
+ }
+
+ Ok(true)
+}
+
// Empty array is a special case that is useful for many other array functions
pub(super) fn empty_array_type() -> DataType {
DataType::List(Arc::new(Field::new("item", DataType::Int64, true)))
diff --git a/datafusion/sqllogictest/test_files/struct.slt
b/datafusion/sqllogictest/test_files/struct.slt
index 67cd7d71fc..b76c78396a 100644
--- a/datafusion/sqllogictest/test_files/struct.slt
+++ b/datafusion/sqllogictest/test_files/struct.slt
@@ -374,6 +374,34 @@ You reached the bottom!
statement ok
drop view complex_view;
+# struct with different keys r1 and r2 is not valid
+statement ok
+create table t(a struct<r1 varchar, c int>, b struct<r2 varchar, c float>) as
values (struct('red', 1), struct('blue', 2.3));
+
+# Expect same keys for struct type but got mismatched pair r1,c and r2,c
+query error
+select [a, b] from t;
+
+statement ok
+drop table t;
+
+# struct with the same key
+statement ok
+create table t(a struct<r varchar, c int>, b struct<r varchar, c float>) as
values (struct('red', 1), struct('blue', 2.3));
+
+query T
+select arrow_typeof([a, b]) from t;
+----
+List(Field { name: "item", data_type: Struct([Field { name: "r", data_type:
Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field
{ name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered:
false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false,
metadata: {} })
+
+query ?
+select [a, b] from t;
+----
+[{r: red, c: 1}, {r: blue, c: 2}]
+
+statement ok
+drop table t;
+
# Test row alias
query ?
@@ -412,7 +440,6 @@ select * from t;
----
{r: red, b: 2} {r: blue, b: 2.3}
-# TODO: Should be coerced to float
query T
select arrow_typeof(c1) from t;
----
@@ -422,3 +449,83 @@ query T
select arrow_typeof(c2) from t;
----
Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Float32,
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
+
+statement ok
+drop table t;
+
+##################################
+## Test Coalesce with Struct
+##################################
+
+statement ok
+CREATE TABLE t (
+ s1 struct(a int, b varchar),
+ s2 struct(a float, b varchar)
+) AS VALUES
+ (row(1, 'red'), row(1.1, 'string1')),
+ (row(2, 'blue'), row(2.2, 'string2')),
+ (row(3, 'green'), row(33.2, 'string3'))
+;
+
+query ?
+select coalesce(s1) from t;
+----
+{a: 1, b: red}
+{a: 2, b: blue}
+{a: 3, b: green}
+
+# TODO: a's type should be float
+query T
+select arrow_typeof(coalesce(s1)) from t;
+----
+Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8,
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
+Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8,
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
+Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8,
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
+
+statement ok
+drop table t;
+
+statement ok
+CREATE TABLE t (
+ s1 struct(a int, b varchar),
+ s2 struct(a float, b varchar)
+) AS VALUES
+ (row(1, 'red'), row(1.1, 'string1')),
+ (null, row(2.2, 'string2')),
+ (row(3, 'green'), row(33.2, 'string3'))
+;
+
+# TODO: second column should not be null
+query ?
+select coalesce(s1) from t;
+----
+{a: 1, b: red}
+NULL
+{a: 3, b: green}
+
+statement ok
+drop table t;
+
+# row() with incorrect order
+statement error DataFusion error: Arrow error: Cast error: Cannot cast string
'blue' to value of Float64 type
+create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as
values
+ (row('red', 1), row(2.3, 'blue')),
+ (row('purple', 1), row('green', 2.3));
+
+# out of order struct literal
+# TODO: This query should not fail
+statement error DataFusion error: Arrow error: Cast error: Cannot cast string
'a' to value of Int64 type
+create table t(a struct(r varchar, c int)) as values ({r: 'a', c: 1}), ({c: 2,
r: 'b'});
+
+##################################
+## Test Array of Struct
+##################################
+
+query ?
+select [{r: 'a', c: 1}, {r: 'b', c: 2}];
+----
+[{r: a, c: 1}, {r: b, c: 2}]
+
+# Can't create a list of struct with different field types
+query error
+select [{r: 'a', c: 1}, {c: 2, r: 'b'}];
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]