This is an automated email from the ASF dual-hosted git repository. alamb pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push: new 8bab267 Add ScalarValue support for arbitrary list elements (#1142) 8bab267 is described below commit 8bab2676e070ee3cfc55d2ec0877c724d4daf568 Author: Jon Mease <jon.mease....@gmail.com> AuthorDate: Wed Oct 20 06:45:48 2021 -0400 Add ScalarValue support for arbitrary list elements (#1142) * clippy fix * clippy fixes * Rebase and review cleanup --- datafusion/src/scalar.rs | 349 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 339 insertions(+), 10 deletions(-) diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 31c48a6..00586bf 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -777,6 +777,11 @@ impl ScalarValue { DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { build_array_list_string!(LargeStringBuilder, LargeUtf8) } + DataType::List(_) => { + // Fallback case handling homogeneous lists with any ScalarValue element type + let list_array = ScalarValue::iter_to_array_list(scalars, &data_type)?; + Arc::new(list_array) + } DataType::Struct(fields) => { // Initialize a Vector to store the ScalarValues for each column let mut columns: Vec<Vec<ScalarValue>> = @@ -833,6 +838,73 @@ impl ScalarValue { Ok(array) } + fn iter_to_array_list( + scalars: impl IntoIterator<Item = ScalarValue>, + data_type: &DataType, + ) -> Result<GenericListArray<i32>> { + let mut offsets = Int32Array::builder(0); + if let Err(err) = offsets.append_value(0) { + return Err(DataFusionError::ArrowError(err)); + } + + let mut elements: Vec<ArrayRef> = Vec::new(); + let mut valid = BooleanBufferBuilder::new(0); + let mut flat_len = 0i32; + for scalar in scalars { + if let ScalarValue::List(values, _) = scalar { + match values { + Some(values) => { + let element_array = ScalarValue::iter_to_array(*values)?; + + // Add new offset index + flat_len += element_array.len() as i32; + if let Err(err) = offsets.append_value(flat_len) { + return Err(DataFusionError::ArrowError(err)); + } + + elements.push(element_array); + + // Element is valid + valid.append(true); + } + None => { + // Repeat previous offset index + if let Err(err) = offsets.append_value(flat_len) { + return Err(DataFusionError::ArrowError(err)); + } + + // Element is null + valid.append(false); + } + } + } else { + return Err(DataFusionError::Internal(format!( + "Expected ScalarValue::List element. Received {:?}", + scalar + ))); + } + } + + // Concatenate element arrays to create single flat array + let element_arrays: Vec<&dyn Array> = + elements.iter().map(|a| a.as_ref()).collect(); + let flat_array = match arrow::compute::concat(&element_arrays) { + Ok(flat_array) => flat_array, + Err(err) => return Err(DataFusionError::ArrowError(err)), + }; + + // Build ListArray using ArrayData so we can specify a flat inner array, and offset indices + let offsets_array = offsets.finish(); + let array_data = ArrayDataBuilder::new(data_type.clone()) + .len(offsets_array.len() - 1) + .null_bit_buffer(valid.finish()) + .add_buffer(offsets_array.data().buffers()[0].clone()) + .add_child_data(flat_array.data().clone()); + + let list_array = ListArray::from(array_data.build()?); + Ok(list_array) + } + /// Converts a scalar value into an array of `size` rows. pub fn to_array_of_size(&self, size: usize) -> ArrayRef { match self { @@ -945,7 +1017,15 @@ impl ScalarValue { &DataType::LargeUtf8 => { build_list!(LargeStringBuilder, LargeUtf8, values, size) } - dt => panic!("Unexpected DataType for list {:?}", dt), + _ => ScalarValue::iter_to_array_list( + repeat(self.clone()).take(size), + &DataType::List(Box::new(Field::new( + "item", + data_type.as_ref().clone(), + true, + ))), + ) + .unwrap(), }), ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) @@ -2252,14 +2332,15 @@ mod tests { } #[test] - fn test_scalar_list_in_struct() { + fn test_lists_in_struct() { let field_a = Field::new("A", DataType::Utf8, false); - let field_list = Field::new( - "list_field", + let field_primitive_list = Field::new( + "primitive_list", DataType::List(Box::new(Field::new("item", DataType::Int32, true))), false, ); + // Define primitive list scalars let l0 = ScalarValue::List( Some(Box::new(vec![ ScalarValue::from(1i32), @@ -2282,31 +2363,34 @@ mod tests { Box::new(DataType::Int32), ); + // Define struct scalars let s0 = ScalarValue::from(vec![ ("A", ScalarValue::Utf8(Some(String::from("First")))), - ("list_field", l0), + ("primitive_list", l0), ]); let s1 = ScalarValue::from(vec![ ("A", ScalarValue::Utf8(Some(String::from("Second")))), - ("list_field", l1), + ("primitive_list", l1), ]); let s2 = ScalarValue::from(vec![ ("A", ScalarValue::Utf8(Some(String::from("Third")))), - ("list_field", l2), + ("primitive_list", l2), ]); - let array = ScalarValue::iter_to_array(vec![s0, s1, s2]).unwrap(); + // iter_to_array for struct scalars + let array = + ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(), s2.clone()]).unwrap(); let array = array.as_any().downcast_ref::<StructArray>().unwrap(); let expected = StructArray::from(vec![ ( - field_a, + field_a.clone(), Arc::new(StringArray::from(vec!["First", "Second", "Third"])) as ArrayRef, ), ( - field_list, + field_primitive_list.clone(), Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![ Some(vec![Some(1), Some(2), Some(3)]), Some(vec![Some(4), Some(5)]), @@ -2316,5 +2400,250 @@ mod tests { ]); assert_eq!(array, &expected); + + // Define list-of-structs scalars + let nl0 = ScalarValue::List( + Some(Box::new(vec![s0.clone(), s1.clone()])), + Box::new(s0.get_datatype()), + ); + + let nl1 = + ScalarValue::List(Some(Box::new(vec![s2])), Box::new(s0.get_datatype())); + + let nl2 = + ScalarValue::List(Some(Box::new(vec![s1])), Box::new(s0.get_datatype())); + + // iter_to_array for list-of-struct + let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); + let array = array.as_any().downcast_ref::<ListArray>().unwrap(); + + // Construct expected array with array builders + let field_a_builder = StringBuilder::new(4); + let primitive_value_builder = Int32Array::builder(8); + let field_primitive_list_builder = ListBuilder::new(primitive_value_builder); + + let element_builder = StructBuilder::new( + vec![field_a, field_primitive_list], + vec![ + Box::new(field_a_builder), + Box::new(field_primitive_list_builder), + ], + ); + let mut list_builder = ListBuilder::new(element_builder); + + list_builder + .values() + .field_builder::<StringBuilder>(0) + .unwrap() + .append_value("First") + .unwrap(); + list_builder + .values() + .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1) + .unwrap() + .values() + .append_value(1) + .unwrap(); + list_builder + .values() + .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1) + .unwrap() + .values() + .append_value(2) + .unwrap(); + list_builder + .values() + .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1) + .unwrap() + .values() + .append_value(3) + .unwrap(); + list_builder + .values() + .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1) + .unwrap() + .append(true) + .unwrap(); + list_builder.values().append(true).unwrap(); + + list_builder + .values() + .field_builder::<StringBuilder>(0) + .unwrap() + .append_value("Second") + .unwrap(); + list_builder + .values() + .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1) + .unwrap() + .values() + .append_value(4) + .unwrap(); + list_builder + .values() + .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1) + .unwrap() + .values() + .append_value(5) + .unwrap(); + list_builder + .values() + .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1) + .unwrap() + .append(true) + .unwrap(); + list_builder.values().append(true).unwrap(); + list_builder.append(true).unwrap(); + + list_builder + .values() + .field_builder::<StringBuilder>(0) + .unwrap() + .append_value("Third") + .unwrap(); + list_builder + .values() + .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1) + .unwrap() + .values() + .append_value(6) + .unwrap(); + list_builder + .values() + .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1) + .unwrap() + .append(true) + .unwrap(); + list_builder.values().append(true).unwrap(); + list_builder.append(true).unwrap(); + + list_builder + .values() + .field_builder::<StringBuilder>(0) + .unwrap() + .append_value("Second") + .unwrap(); + list_builder + .values() + .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1) + .unwrap() + .values() + .append_value(4) + .unwrap(); + list_builder + .values() + .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1) + .unwrap() + .values() + .append_value(5) + .unwrap(); + list_builder + .values() + .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1) + .unwrap() + .append(true) + .unwrap(); + list_builder.values().append(true).unwrap(); + list_builder.append(true).unwrap(); + + let expected = list_builder.finish(); + + assert_eq!(array, &expected); + } + + #[test] + fn test_nested_lists() { + // Define inner list scalars + let l1 = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ])), + Box::new(DataType::Int32), + ), + ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(4i32), + ScalarValue::from(5i32), + ])), + Box::new(DataType::Int32), + ), + ])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + let l2 = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::List( + Some(Box::new(vec![ScalarValue::from(6i32)])), + Box::new(DataType::Int32), + ), + ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(7i32), + ScalarValue::from(8i32), + ])), + Box::new(DataType::Int32), + ), + ])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + let l3 = ScalarValue::List( + Some(Box::new(vec![ScalarValue::List( + Some(Box::new(vec![ScalarValue::from(9i32)])), + Box::new(DataType::Int32), + )])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); + let array = array.as_any().downcast_ref::<ListArray>().unwrap(); + + // Construct expected array with array builders + let inner_builder = Int32Array::builder(8); + let middle_builder = ListBuilder::new(inner_builder); + let mut outer_builder = ListBuilder::new(middle_builder); + + outer_builder.values().values().append_value(1).unwrap(); + outer_builder.values().values().append_value(2).unwrap(); + outer_builder.values().values().append_value(3).unwrap(); + outer_builder.values().append(true).unwrap(); + + outer_builder.values().values().append_value(4).unwrap(); + outer_builder.values().values().append_value(5).unwrap(); + outer_builder.values().append(true).unwrap(); + outer_builder.append(true).unwrap(); + + outer_builder.values().values().append_value(6).unwrap(); + outer_builder.values().append(true).unwrap(); + + outer_builder.values().values().append_value(7).unwrap(); + outer_builder.values().values().append_value(8).unwrap(); + outer_builder.values().append(true).unwrap(); + outer_builder.append(true).unwrap(); + + outer_builder.values().values().append_value(9).unwrap(); + outer_builder.values().append(true).unwrap(); + outer_builder.append(true).unwrap(); + + let expected = outer_builder.finish(); + + assert_eq!(array, &expected); } }