This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 448dff519c Fix `ScalarValue` handling of NULL values for ListArray 
(#7969)
448dff519c is described below

commit 448dff519c2c46189b5909b542f15df36766544b
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Oct 30 13:23:00 2023 -0700

    Fix `ScalarValue` handling of NULL values for ListArray (#7969)
    
    * Fix try_from_array data type for NULL value in ListArray
    
    * Fix
    
    * Explicitly assert the datatype
    
    * For review
---
 datafusion/common/src/scalar.rs              | 125 +++++++++++++++++++++------
 datafusion/sqllogictest/test_files/array.slt |  11 +++
 2 files changed, 111 insertions(+), 25 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index be24e2b933..f9b0cbdf22 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -1312,10 +1312,11 @@ impl ScalarValue {
                 Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>(
                     scalars.into_iter().map(|x| match x {
                         ScalarValue::List(arr) => {
-                            if 
arr.as_any().downcast_ref::<NullArray>().is_some() {
+                            // `ScalarValue::List` contains a single element 
`ListArray`.
+                            let list_arr = as_list_array(&arr);
+                            if list_arr.is_null(0) {
                                 None
                             } else {
-                                let list_arr = as_list_array(&arr);
                                 let primitive_arr =
                                     
list_arr.values().as_primitive::<$ARRAY_TY>();
                                 Some(
@@ -1339,12 +1340,14 @@ impl ScalarValue {
                 for scalar in scalars.into_iter() {
                     match scalar {
                         ScalarValue::List(arr) => {
-                            if 
arr.as_any().downcast_ref::<NullArray>().is_some() {
+                            // `ScalarValue::List` contains a single element 
`ListArray`.
+                            let list_arr = as_list_array(&arr);
+
+                            if list_arr.is_null(0) {
                                 builder.append(false);
                                 continue;
                             }
 
-                            let list_arr = as_list_array(&arr);
                             let string_arr = $STRING_ARRAY(list_arr.values());
 
                             for v in string_arr.iter() {
@@ -1699,15 +1702,16 @@ impl ScalarValue {
 
         for scalar in scalars {
             if let ScalarValue::List(arr) = scalar {
-                // i.e. NullArray(1)
-                if arr.as_any().downcast_ref::<NullArray>().is_some() {
+                // `ScalarValue::List` contains a single element `ListArray`.
+                let list_arr = as_list_array(&arr);
+
+                if list_arr.is_null(0) {
                     // Repeat previous offset index
                     offsets.push(0);
 
                     // Element is null
                     valid.append(false);
                 } else {
-                    let list_arr = as_list_array(&arr);
                     let arr = list_arr.values().to_owned();
                     offsets.push(arr.len());
                     elements.push(arr);
@@ -2234,28 +2238,20 @@ impl ScalarValue {
             }
             DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8),
             DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, 
LargeUtf8),
-            DataType::List(nested_type) => {
+            DataType::List(_) => {
                 let list_array = as_list_array(array);
-                let arr = match list_array.is_null(index) {
-                    true => new_null_array(nested_type.data_type(), 0),
-                    false => {
-                        let nested_array = list_array.value(index);
-                        Arc::new(wrap_into_list_array(nested_array))
-                    }
-                };
+                let nested_array = list_array.value(index);
+                // Produces a single element `ListArray` with the value at 
`index`.
+                let arr = Arc::new(wrap_into_list_array(nested_array));
 
                 ScalarValue::List(arr)
             }
             // TODO: There is no test for FixedSizeList now, add it later
-            DataType::FixedSizeList(nested_type, _len) => {
+            DataType::FixedSizeList(_, _) => {
                 let list_array = as_fixed_size_list_array(array)?;
-                let arr = match list_array.is_null(index) {
-                    true => new_null_array(nested_type.data_type(), 0),
-                    false => {
-                        let nested_array = list_array.value(index);
-                        Arc::new(wrap_into_list_array(nested_array))
-                    }
-                };
+                let nested_array = list_array.value(index);
+                // Produces a single element `ListArray` with the value at 
`index`.
+                let arr = Arc::new(wrap_into_list_array(nested_array));
 
                 ScalarValue::List(arr)
             }
@@ -2944,8 +2940,15 @@ impl TryFrom<&DataType> for ScalarValue {
                 index_type.clone(),
                 Box::new(value_type.as_ref().try_into()?),
             ),
-            DataType::List(_) => 
ScalarValue::List(new_null_array(&DataType::Null, 0)),
-
+            // `ScalarValue::List` contains a single element `ListArray`.
+            DataType::List(field) => ScalarValue::List(new_null_array(
+                &DataType::List(Arc::new(Field::new(
+                    "item",
+                    field.data_type().clone(),
+                    true,
+                ))),
+                1,
+            )),
             DataType::Struct(fields) => ScalarValue::Struct(None, 
fields.clone()),
             DataType::Null => ScalarValue::Null,
             _ => {
@@ -3885,6 +3888,78 @@ mod tests {
         );
     }
 
+    #[test]
+    fn scalar_try_from_array_list_array_null() {
+        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+            Some(vec![Some(1), Some(2)]),
+            None,
+        ]);
+
+        let non_null_list_scalar = ScalarValue::try_from_array(&list, 
0).unwrap();
+        let null_list_scalar = ScalarValue::try_from_array(&list, 1).unwrap();
+
+        let data_type =
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, 
true)));
+
+        assert_eq!(non_null_list_scalar.data_type(), data_type.clone());
+        assert_eq!(null_list_scalar.data_type(), data_type);
+    }
+
+    #[test]
+    fn scalar_try_from_list() {
+        let data_type =
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, 
true)));
+        let data_type = &data_type;
+        let scalar: ScalarValue = data_type.try_into().unwrap();
+
+        let expected = ScalarValue::List(new_null_array(
+            &DataType::List(Arc::new(Field::new("item", DataType::Int32, 
true))),
+            1,
+        ));
+
+        assert_eq!(expected, scalar)
+    }
+
+    #[test]
+    fn scalar_try_from_list_of_list() {
+        let data_type = DataType::List(Arc::new(Field::new(
+            "item",
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, 
true))),
+            true,
+        )));
+        let data_type = &data_type;
+        let scalar: ScalarValue = data_type.try_into().unwrap();
+
+        let expected = ScalarValue::List(new_null_array(
+            &DataType::List(Arc::new(Field::new(
+                "item",
+                DataType::List(Arc::new(Field::new("item", DataType::Int32, 
true))),
+                true,
+            ))),
+            1,
+        ));
+
+        assert_eq!(expected, scalar)
+    }
+
+    #[test]
+    fn scalar_try_from_not_equal_list_nested_list() {
+        let list_data_type =
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, 
true)));
+        let data_type = &list_data_type;
+        let list_scalar: ScalarValue = data_type.try_into().unwrap();
+
+        let nested_list_data_type = DataType::List(Arc::new(Field::new(
+            "item",
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, 
true))),
+            true,
+        )));
+        let data_type = &nested_list_data_type;
+        let nested_list_scalar: ScalarValue = data_type.try_into().unwrap();
+
+        assert_ne!(list_scalar, nested_list_scalar);
+    }
+
     #[test]
     fn scalar_try_from_dict_datatype() {
         let data_type =
diff --git a/datafusion/sqllogictest/test_files/array.slt 
b/datafusion/sqllogictest/test_files/array.slt
index 621cb4a8f4..b5601a2222 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -209,6 +209,17 @@ AS VALUES
   (make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 
33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), 
[28, 29, 30], [37, 38, 39], 10)
 ;
 
+query TTT
+select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) 
from arrays;
+----
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: 
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field 
{ name: "item", data_type: Float64, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: 
Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: 
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field 
{ name: "item", data_type: Float64, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: 
Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: 
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field 
{ name: "item", data_type: Float64, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: 
Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: 
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field 
{ name: "item", data_type: Float64, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: 
Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: 
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field 
{ name: "item", data_type: Float64, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: 
Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: 
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field 
{ name: "item", data_type: Float64, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: 
Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+List(Field { name: "item", data_type: List(Field { name: "item", data_type: 
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) List(Field 
{ name: "item", data_type: Float64, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }) List(Field { name: "item", data_type: 
Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })
+
 # arrays table
 query ???
 select column1, column2, column3 from arrays;

Reply via email to