This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 448dff519c Fix `ScalarValue` handling of NULL values for ListArray (#7969)
448dff519c is described below
commit 448dff519c2c46189b5909b542f15df36766544b
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Oct 30 13:23:00 2023 -0700
Fix `ScalarValue` handling of NULL values for ListArray (#7969)
* Fix try_from_array data type for NULL value in ListArray
* Fix
* Explicitly assert the datatype
* For review
---
datafusion/common/src/scalar.rs | 125 +++++++++++++++++++++------
datafusion/sqllogictest/test_files/array.slt | 11 +++
2 files changed, 111 insertions(+), 25 deletions(-)
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index be24e2b933..f9b0cbdf22 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -1312,10 +1312,11 @@ impl ScalarValue {
Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>(
scalars.into_iter().map(|x| match x {
ScalarValue::List(arr) => {
- if
arr.as_any().downcast_ref::<NullArray>().is_some() {
+ // `ScalarValue::List` contains a single element `ListArray`.
+ let list_arr = as_list_array(&arr);
+ if list_arr.is_null(0) {
None
} else {
- let list_arr = as_list_array(&arr);
let primitive_arr =
list_arr.values().as_primitive::<$ARRAY_TY>();
Some(
@@ -1339,12 +1340,14 @@ impl ScalarValue {
for scalar in scalars.into_iter() {
match scalar {
ScalarValue::List(arr) => {
- if
arr.as_any().downcast_ref::<NullArray>().is_some() {
+ // `ScalarValue::List` contains a single element `ListArray`.
+ let list_arr = as_list_array(&arr);
+
+ if list_arr.is_null(0) {
builder.append(false);
continue;
}
- let list_arr = as_list_array(&arr);
let string_arr = $STRING_ARRAY(list_arr.values());
for v in string_arr.iter() {
@@ -1699,15 +1702,16 @@ impl ScalarValue {
for scalar in scalars {
if let ScalarValue::List(arr) = scalar {
- // i.e. NullArray(1)
- if arr.as_any().downcast_ref::<NullArray>().is_some() {
+ // `ScalarValue::List` contains a single element `ListArray`.
+ let list_arr = as_list_array(&arr);
+
+ if list_arr.is_null(0) {
// Repeat previous offset index
offsets.push(0);
// Element is null
valid.append(false);
} else {
- let list_arr = as_list_array(&arr);
let arr = list_arr.values().to_owned();
offsets.push(arr.len());
elements.push(arr);
@@ -2234,28 +2238,20 @@ impl ScalarValue {
}
DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray,
LargeUtf8),
- DataType::List(nested_type) => {
+ DataType::List(_) => {
let list_array = as_list_array(array);
- let arr = match list_array.is_null(index) {
- true => new_null_array(nested_type.data_type(), 0),
- false => {
- let nested_array = list_array.value(index);
- Arc::new(wrap_into_list_array(nested_array))
- }
- };
+ let nested_array = list_array.value(index);
+ // Produces a single element `ListArray` with the value at `index`.
+ let arr = Arc::new(wrap_into_list_array(nested_array));
ScalarValue::List(arr)
}
// TODO: There is no test for FixedSizeList now, add it later
- DataType::FixedSizeList(nested_type, _len) => {
+ DataType::FixedSizeList(_, _) => {
let list_array = as_fixed_size_list_array(array)?;
- let arr = match list_array.is_null(index) {
- true => new_null_array(nested_type.data_type(), 0),
- false => {
- let nested_array = list_array.value(index);
- Arc::new(wrap_into_list_array(nested_array))
- }
- };
+ let nested_array = list_array.value(index);
+ // Produces a single element `ListArray` with the value at `index`.
+ let arr = Arc::new(wrap_into_list_array(nested_array));
ScalarValue::List(arr)
}
@@ -2944,8 +2940,15 @@ impl TryFrom<&DataType> for ScalarValue {
index_type.clone(),
Box::new(value_type.as_ref().try_into()?),
),
- DataType::List(_) =>
ScalarValue::List(new_null_array(&DataType::Null, 0)),
-
+ // `ScalarValue::List` contains a single element `ListArray`.
+ DataType::List(field) => ScalarValue::List(new_null_array(
+ &DataType::List(Arc::new(Field::new(
+ "item",
+ field.data_type().clone(),
+ true,
+ ))),
+ 1,
+ )),
DataType::Struct(fields) => ScalarValue::Struct(None,
fields.clone()),
DataType::Null => ScalarValue::Null,
_ => {
@@ -3885,6 +3888,78 @@ mod tests {
);
}
+ #[test]
+ fn scalar_try_from_array_list_array_null() {
+ let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(vec![Some(1), Some(2)]),
+ None,
+ ]);
+
+ let non_null_list_scalar = ScalarValue::try_from_array(&list,
0).unwrap();
+ let null_list_scalar = ScalarValue::try_from_array(&list, 1).unwrap();
+
+ let data_type =
+ DataType::List(Arc::new(Field::new("item", DataType::Int32,
true)));
+
+ assert_eq!(non_null_list_scalar.data_type(), data_type.clone());
+ assert_eq!(null_list_scalar.data_type(), data_type);
+ }
+
+ #[test]
+ fn scalar_try_from_list() {
+ let data_type =
+ DataType::List(Arc::new(Field::new("item", DataType::Int32,
true)));
+ let data_type = &data_type;
+ let scalar: ScalarValue = data_type.try_into().unwrap();
+
+ let expected = ScalarValue::List(new_null_array(
+ &DataType::List(Arc::new(Field::new("item", DataType::Int32,
true))),
+ 1,
+ ));
+
+ assert_eq!(expected, scalar)
+ }
+
+ #[test]
+ fn scalar_try_from_list_of_list() {
+ let data_type = DataType::List(Arc::new(Field::new(
+ "item",
+ DataType::List(Arc::new(Field::new("item", DataType::Int32,
true))),
+ true,
+ )));
+ let data_type = &data_type;
+ let scalar: ScalarValue = data_type.try_into().unwrap();
+
+ let expected = ScalarValue::List(new_null_array(
+ &DataType::List(Arc::new(Field::new(
+ "item",
+ DataType::List(Arc::new(Field::new("item", DataType::Int32,
true))),
+ true,
+ ))),
+ 1,
+ ));
+
+ assert_eq!(expected, scalar)
+ }
+
+ #[test]
+ fn scalar_try_from_not_equal_list_nested_list() {
+ let list_data_type =
+ DataType::List(Arc::new(Field::new("item", DataType::Int32,
true)));
+ let data_type = &list_data_type;
+ let list_scalar: ScalarValue = data_type.try_into().unwrap();
+
+ let nested_list_data_type = DataType::List(Arc::new(Field::new(
+ "item",
+ DataType::List(Arc::new(Field::new("item", DataType::Int32,
true))),
+ true,
+ )));
+ let data_type = &nested_list_data_type;
+ let nested_list_scalar: ScalarValue = data_type.try_into().unwrap();
+
+ assert_ne!(list_scalar, nested_list_scalar);
+ }
+
#[test]
fn scalar_try_from_dict_datatype() {
let data_type =
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index 621cb4a8f4..b5601a2222 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -209,6 +209,17 @@ AS VALUES
(make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32,
33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]),
[28, 29, 30], [37, 38, 39], 10)
;
+query TTT
+select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3)
from arrays;
+----
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+
# arrays table
query ???
select column1, column2, column3 from arrays;