This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 8bab267  Add ScalarValue support for arbitrary list elements (#1142)
8bab267 is described below

commit 8bab2676e070ee3cfc55d2ec0877c724d4daf568
Author: Jon Mease <jon.mease....@gmail.com>
AuthorDate: Wed Oct 20 06:45:48 2021 -0400

    Add ScalarValue support for arbitrary list elements (#1142)
    
    * clippy fix
    
    * clippy fixes
    
    * Rebase and review cleanup
---
 datafusion/src/scalar.rs | 349 +++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 339 insertions(+), 10 deletions(-)

diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index 31c48a6..00586bf 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -777,6 +777,11 @@ impl ScalarValue {
             DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => {
                 build_array_list_string!(LargeStringBuilder, LargeUtf8)
             }
+            DataType::List(_) => {
+                // Fallback case handling homogeneous lists with any ScalarValue element type
+                let list_array = ScalarValue::iter_to_array_list(scalars, &data_type)?;
+                Arc::new(list_array)
+            }
             DataType::Struct(fields) => {
                 // Initialize a Vector to store the ScalarValues for each column
                 let mut columns: Vec<Vec<ScalarValue>> =
@@ -833,6 +838,73 @@ impl ScalarValue {
         Ok(array)
     }
 
+    fn iter_to_array_list(
+        scalars: impl IntoIterator<Item = ScalarValue>,
+        data_type: &DataType,
+    ) -> Result<GenericListArray<i32>> {
+        let mut offsets = Int32Array::builder(0);
+        if let Err(err) = offsets.append_value(0) {
+            return Err(DataFusionError::ArrowError(err));
+        }
+
+        let mut elements: Vec<ArrayRef> = Vec::new();
+        let mut valid = BooleanBufferBuilder::new(0);
+        let mut flat_len = 0i32;
+        for scalar in scalars {
+            if let ScalarValue::List(values, _) = scalar {
+                match values {
+                    Some(values) => {
+                        let element_array = ScalarValue::iter_to_array(*values)?;
+
+                        // Add new offset index
+                        flat_len += element_array.len() as i32;
+                        if let Err(err) = offsets.append_value(flat_len) {
+                            return Err(DataFusionError::ArrowError(err));
+                        }
+
+                        elements.push(element_array);
+
+                        // Element is valid
+                        valid.append(true);
+                    }
+                    None => {
+                        // Repeat previous offset index
+                        if let Err(err) = offsets.append_value(flat_len) {
+                            return Err(DataFusionError::ArrowError(err));
+                        }
+
+                        // Element is null
+                        valid.append(false);
+                    }
+                }
+            } else {
+                return Err(DataFusionError::Internal(format!(
+                    "Expected ScalarValue::List element. Received {:?}",
+                    scalar
+                )));
+            }
+        }
+
+        // Concatenate element arrays to create single flat array
+        let element_arrays: Vec<&dyn Array> =
+            elements.iter().map(|a| a.as_ref()).collect();
+        let flat_array = match arrow::compute::concat(&element_arrays) {
+            Ok(flat_array) => flat_array,
+            Err(err) => return Err(DataFusionError::ArrowError(err)),
+        };
+
+        // Build ListArray using ArrayData so we can specify a flat inner array, and offset indices
+        let offsets_array = offsets.finish();
+        let array_data = ArrayDataBuilder::new(data_type.clone())
+            .len(offsets_array.len() - 1)
+            .null_bit_buffer(valid.finish())
+            .add_buffer(offsets_array.data().buffers()[0].clone())
+            .add_child_data(flat_array.data().clone());
+
+        let list_array = ListArray::from(array_data.build()?);
+        Ok(list_array)
+    }
+
     /// Converts a scalar value into an array of `size` rows.
     pub fn to_array_of_size(&self, size: usize) -> ArrayRef {
         match self {
@@ -945,7 +1017,15 @@ impl ScalarValue {
                 &DataType::LargeUtf8 => {
                     build_list!(LargeStringBuilder, LargeUtf8, values, size)
                 }
-                dt => panic!("Unexpected DataType for list {:?}", dt),
+                _ => ScalarValue::iter_to_array_list(
+                    repeat(self.clone()).take(size),
+                    &DataType::List(Box::new(Field::new(
+                        "item",
+                        data_type.as_ref().clone(),
+                        true,
+                    ))),
+                )
+                .unwrap(),
             }),
             ScalarValue::Date32(e) => {
                 build_array_from_option!(Date32, Date32Array, e, size)
@@ -2252,14 +2332,15 @@ mod tests {
     }
 
     #[test]
-    fn test_scalar_list_in_struct() {
+    fn test_lists_in_struct() {
         let field_a = Field::new("A", DataType::Utf8, false);
-        let field_list = Field::new(
-            "list_field",
+        let field_primitive_list = Field::new(
+            "primitive_list",
             DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
             false,
         );
 
+        // Define primitive list scalars
         let l0 = ScalarValue::List(
             Some(Box::new(vec![
                 ScalarValue::from(1i32),
@@ -2282,31 +2363,34 @@ mod tests {
             Box::new(DataType::Int32),
         );
 
+        // Define struct scalars
         let s0 = ScalarValue::from(vec![
             ("A", ScalarValue::Utf8(Some(String::from("First")))),
-            ("list_field", l0),
+            ("primitive_list", l0),
         ]);
 
         let s1 = ScalarValue::from(vec![
             ("A", ScalarValue::Utf8(Some(String::from("Second")))),
-            ("list_field", l1),
+            ("primitive_list", l1),
         ]);
 
         let s2 = ScalarValue::from(vec![
             ("A", ScalarValue::Utf8(Some(String::from("Third")))),
-            ("list_field", l2),
+            ("primitive_list", l2),
         ]);
 
-        let array = ScalarValue::iter_to_array(vec![s0, s1, s2]).unwrap();
+        // iter_to_array for struct scalars
+        let array =
+            ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(), s2.clone()]).unwrap();
         let array = array.as_any().downcast_ref::<StructArray>().unwrap();
 
         let expected = StructArray::from(vec![
             (
-                field_a,
+                field_a.clone(),
                 Arc::new(StringArray::from(vec!["First", "Second", "Third"])) as ArrayRef,
             ),
             (
-                field_list,
+                field_primitive_list.clone(),
                 Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                     Some(vec![Some(1), Some(2), Some(3)]),
                     Some(vec![Some(4), Some(5)]),
@@ -2316,5 +2400,250 @@ mod tests {
         ]);
 
         assert_eq!(array, &expected);
+
+        // Define list-of-structs scalars
+        let nl0 = ScalarValue::List(
+            Some(Box::new(vec![s0.clone(), s1.clone()])),
+            Box::new(s0.get_datatype()),
+        );
+
+        let nl1 =
+            ScalarValue::List(Some(Box::new(vec![s2])), Box::new(s0.get_datatype()));
+
+        let nl2 =
+            ScalarValue::List(Some(Box::new(vec![s1])), Box::new(s0.get_datatype()));
+
+        // iter_to_array for list-of-struct
+        let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();
+        let array = array.as_any().downcast_ref::<ListArray>().unwrap();
+
+        // Construct expected array with array builders
+        let field_a_builder = StringBuilder::new(4);
+        let primitive_value_builder = Int32Array::builder(8);
+        let field_primitive_list_builder = ListBuilder::new(primitive_value_builder);
+
+        let element_builder = StructBuilder::new(
+            vec![field_a, field_primitive_list],
+            vec![
+                Box::new(field_a_builder),
+                Box::new(field_primitive_list_builder),
+            ],
+        );
+        let mut list_builder = ListBuilder::new(element_builder);
+
+        list_builder
+            .values()
+            .field_builder::<StringBuilder>(0)
+            .unwrap()
+            .append_value("First")
+            .unwrap();
+        list_builder
+            .values()
+            .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+            .unwrap()
+            .values()
+            .append_value(1)
+            .unwrap();
+        list_builder
+            .values()
+            .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+            .unwrap()
+            .values()
+            .append_value(2)
+            .unwrap();
+        list_builder
+            .values()
+            .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+            .unwrap()
+            .values()
+            .append_value(3)
+            .unwrap();
+        list_builder
+            .values()
+            .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+            .unwrap()
+            .append(true)
+            .unwrap();
+        list_builder.values().append(true).unwrap();
+
+        list_builder
+            .values()
+            .field_builder::<StringBuilder>(0)
+            .unwrap()
+            .append_value("Second")
+            .unwrap();
+        list_builder
+            .values()
+            .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+            .unwrap()
+            .values()
+            .append_value(4)
+            .unwrap();
+        list_builder
+            .values()
+            .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+            .unwrap()
+            .values()
+            .append_value(5)
+            .unwrap();
+        list_builder
+            .values()
+            .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+            .unwrap()
+            .append(true)
+            .unwrap();
+        list_builder.values().append(true).unwrap();
+        list_builder.append(true).unwrap();
+
+        list_builder
+            .values()
+            .field_builder::<StringBuilder>(0)
+            .unwrap()
+            .append_value("Third")
+            .unwrap();
+        list_builder
+            .values()
+            .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+            .unwrap()
+            .values()
+            .append_value(6)
+            .unwrap();
+        list_builder
+            .values()
+            .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+            .unwrap()
+            .append(true)
+            .unwrap();
+        list_builder.values().append(true).unwrap();
+        list_builder.append(true).unwrap();
+
+        list_builder
+            .values()
+            .field_builder::<StringBuilder>(0)
+            .unwrap()
+            .append_value("Second")
+            .unwrap();
+        list_builder
+            .values()
+            .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+            .unwrap()
+            .values()
+            .append_value(4)
+            .unwrap();
+        list_builder
+            .values()
+            .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+            .unwrap()
+            .values()
+            .append_value(5)
+            .unwrap();
+        list_builder
+            .values()
+            .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+            .unwrap()
+            .append(true)
+            .unwrap();
+        list_builder.values().append(true).unwrap();
+        list_builder.append(true).unwrap();
+
+        let expected = list_builder.finish();
+
+        assert_eq!(array, &expected);
+    }
+
+    #[test]
+    fn test_nested_lists() {
+        // Define inner list scalars
+        let l1 = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(1i32),
+                        ScalarValue::from(2i32),
+                        ScalarValue::from(3i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(4i32),
+                        ScalarValue::from(5i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+            ])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let l2 = ScalarValue::List(
+            Some(Box::new(vec![
+                ScalarValue::List(
+                    Some(Box::new(vec![ScalarValue::from(6i32)])),
+                    Box::new(DataType::Int32),
+                ),
+                ScalarValue::List(
+                    Some(Box::new(vec![
+                        ScalarValue::from(7i32),
+                        ScalarValue::from(8i32),
+                    ])),
+                    Box::new(DataType::Int32),
+                ),
+            ])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let l3 = ScalarValue::List(
+            Some(Box::new(vec![ScalarValue::List(
+                Some(Box::new(vec![ScalarValue::from(9i32)])),
+                Box::new(DataType::Int32),
+            )])),
+            Box::new(DataType::List(Box::new(Field::new(
+                "item",
+                DataType::Int32,
+                true,
+            )))),
+        );
+
+        let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
+        let array = array.as_any().downcast_ref::<ListArray>().unwrap();
+
+        // Construct expected array with array builders
+        let inner_builder = Int32Array::builder(8);
+        let middle_builder = ListBuilder::new(inner_builder);
+        let mut outer_builder = ListBuilder::new(middle_builder);
+
+        outer_builder.values().values().append_value(1).unwrap();
+        outer_builder.values().values().append_value(2).unwrap();
+        outer_builder.values().values().append_value(3).unwrap();
+        outer_builder.values().append(true).unwrap();
+
+        outer_builder.values().values().append_value(4).unwrap();
+        outer_builder.values().values().append_value(5).unwrap();
+        outer_builder.values().append(true).unwrap();
+        outer_builder.append(true).unwrap();
+
+        outer_builder.values().values().append_value(6).unwrap();
+        outer_builder.values().append(true).unwrap();
+
+        outer_builder.values().values().append_value(7).unwrap();
+        outer_builder.values().values().append_value(8).unwrap();
+        outer_builder.values().append(true).unwrap();
+        outer_builder.append(true).unwrap();
+
+        outer_builder.values().values().append_value(9).unwrap();
+        outer_builder.values().append(true).unwrap();
+        outer_builder.append(true).unwrap();
+
+        let expected = outer_builder.finish();
+
+        assert_eq!(array, &expected);
     }
 }

Reply via email to