This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 8bab267 Add ScalarValue support for arbitrary list elements (#1142)
8bab267 is described below
commit 8bab2676e070ee3cfc55d2ec0877c724d4daf568
Author: Jon Mease <[email protected]>
AuthorDate: Wed Oct 20 06:45:48 2021 -0400
Add ScalarValue support for arbitrary list elements (#1142)
* clippy fix
* clippy fixes
* Rebase and review cleanup
---
datafusion/src/scalar.rs | 349 +++++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 339 insertions(+), 10 deletions(-)
diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index 31c48a6..00586bf 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -777,6 +777,11 @@ impl ScalarValue {
DataType::List(fields) if fields.data_type() ==
&DataType::LargeUtf8 => {
build_array_list_string!(LargeStringBuilder, LargeUtf8)
}
+ DataType::List(_) => {
+ // Fallback case handling homogeneous lists with any ScalarValue element type
+ let list_array = ScalarValue::iter_to_array_list(scalars,
&data_type)?;
+ Arc::new(list_array)
+ }
DataType::Struct(fields) => {
// Initialize a Vector to store the ScalarValues for each column
let mut columns: Vec<Vec<ScalarValue>> =
@@ -833,6 +838,73 @@ impl ScalarValue {
Ok(array)
}
+    /// Builds a `ListArray` from an iterator of `ScalarValue::List` scalars.
+    ///
+    /// Each scalar becomes one slot of the resulting array:
+    /// * `ScalarValue::List(Some(values), _)` is converted with
+    ///   `ScalarValue::iter_to_array` and its elements are appended to the
+    ///   flat child array; the slot is marked valid.
+    /// * `ScalarValue::List(Some(values), _)` with no elements becomes a
+    ///   valid, zero-length slot.
+    /// * `ScalarValue::List(None, _)` becomes a null slot that repeats the
+    ///   previous offset.
+    ///
+    /// `data_type` must be the `DataType::List` of the array being built. Its
+    /// child field type is used for the values array when no scalar
+    /// contributes any element.
+    ///
+    /// # Errors
+    /// Returns `DataFusionError::Internal` if `data_type` is not a
+    /// `DataType::List`, if a scalar is not a `ScalarValue::List`, or if the
+    /// total number of child elements does not fit in an `i32` offset.
+    /// Errors raised by the arrow builders or by `concat` are returned as
+    /// `DataFusionError::ArrowError`.
+    fn iter_to_array_list(
+        scalars: impl IntoIterator<Item = ScalarValue>,
+        data_type: &DataType,
+    ) -> Result<GenericListArray<i32>> {
+        // Child element type, needed to create an empty values array
+        let element_type = match data_type {
+            DataType::List(field) => field.data_type(),
+            other => {
+                return Err(DataFusionError::Internal(format!(
+                    "Expected DataType::List. Received {:?}",
+                    other
+                )))
+            }
+        };
+
+        let mut offsets = Int32Array::builder(0);
+        offsets
+            .append_value(0)
+            .map_err(DataFusionError::ArrowError)?;
+
+        let mut elements: Vec<ArrayRef> = Vec::new();
+        let mut valid = BooleanBufferBuilder::new(0);
+        let mut flat_len = 0i32;
+        for scalar in scalars {
+            match scalar {
+                ScalarValue::List(Some(values), _) => {
+                    // An empty list adds no child values, so it is not handed
+                    // to `iter_to_array` (presumably it cannot infer a type
+                    // from an empty iterator -- confirm against its body).
+                    if !values.is_empty() {
+                        let element_array = ScalarValue::iter_to_array(*values)?;
+
+                        // `flat_len` is never negative, so the subtraction
+                        // cannot underflow and the cast to usize is lossless.
+                        let len = element_array.len();
+                        if len > (i32::MAX - flat_len) as usize {
+                            return Err(DataFusionError::Internal(format!(
+                                "List offsets overflow i32: {} + {} elements",
+                                flat_len, len
+                            )));
+                        }
+                        flat_len += len as i32;
+
+                        elements.push(element_array);
+                    }
+
+                    // Add new offset index
+                    offsets
+                        .append_value(flat_len)
+                        .map_err(DataFusionError::ArrowError)?;
+
+                    // Element is valid
+                    valid.append(true);
+                }
+                ScalarValue::List(None, _) => {
+                    // Repeat previous offset index
+                    offsets
+                        .append_value(flat_len)
+                        .map_err(DataFusionError::ArrowError)?;
+
+                    // Element is null
+                    valid.append(false);
+                }
+                other => {
+                    return Err(DataFusionError::Internal(format!(
+                        "Expected ScalarValue::List element. Received {:?}",
+                        other
+                    )));
+                }
+            }
+        }
+
+        // Concatenate element arrays to create single flat array. `concat` is
+        // only called with at least one input; when every list was null or
+        // empty, an empty array of the child type is used instead.
+        let flat_array = if elements.is_empty() {
+            arrow::array::new_empty_array(element_type)
+        } else {
+            let element_arrays: Vec<&dyn Array> =
+                elements.iter().map(|a| a.as_ref()).collect();
+            arrow::compute::concat(&element_arrays)
+                .map_err(DataFusionError::ArrowError)?
+        };
+
+        // Build ListArray using ArrayData so we can specify a flat inner array, and offset indices
+        let offsets_array = offsets.finish();
+        let array_data = ArrayDataBuilder::new(data_type.clone())
+            .len(offsets_array.len() - 1)
+            .null_bit_buffer(valid.finish())
+            .add_buffer(offsets_array.data().buffers()[0].clone())
+            .add_child_data(flat_array.data().clone());
+
+        let list_array = ListArray::from(array_data.build()?);
+        Ok(list_array)
+    }
+
/// Converts a scalar value into an array of `size` rows.
pub fn to_array_of_size(&self, size: usize) -> ArrayRef {
match self {
@@ -945,7 +1017,15 @@ impl ScalarValue {
&DataType::LargeUtf8 => {
build_list!(LargeStringBuilder, LargeUtf8, values, size)
}
- dt => panic!("Unexpected DataType for list {:?}", dt),
+ _ => ScalarValue::iter_to_array_list(
+ repeat(self.clone()).take(size),
+ &DataType::List(Box::new(Field::new(
+ "item",
+ data_type.as_ref().clone(),
+ true,
+ ))),
+ )
+ .unwrap(),
}),
ScalarValue::Date32(e) => {
build_array_from_option!(Date32, Date32Array, e, size)
@@ -2252,14 +2332,15 @@ mod tests {
}
#[test]
- fn test_scalar_list_in_struct() {
+ fn test_lists_in_struct() {
let field_a = Field::new("A", DataType::Utf8, false);
- let field_list = Field::new(
- "list_field",
+ let field_primitive_list = Field::new(
+ "primitive_list",
DataType::List(Box::new(Field::new("item", DataType::Int32,
true))),
false,
);
+ // Define primitive list scalars
let l0 = ScalarValue::List(
Some(Box::new(vec![
ScalarValue::from(1i32),
@@ -2282,31 +2363,34 @@ mod tests {
Box::new(DataType::Int32),
);
+ // Define struct scalars
let s0 = ScalarValue::from(vec![
("A", ScalarValue::Utf8(Some(String::from("First")))),
- ("list_field", l0),
+ ("primitive_list", l0),
]);
let s1 = ScalarValue::from(vec![
("A", ScalarValue::Utf8(Some(String::from("Second")))),
- ("list_field", l1),
+ ("primitive_list", l1),
]);
let s2 = ScalarValue::from(vec![
("A", ScalarValue::Utf8(Some(String::from("Third")))),
- ("list_field", l2),
+ ("primitive_list", l2),
]);
- let array = ScalarValue::iter_to_array(vec![s0, s1, s2]).unwrap();
+ // iter_to_array for struct scalars
+ let array =
+ ScalarValue::iter_to_array(vec![s0.clone(), s1.clone(),
s2.clone()]).unwrap();
let array = array.as_any().downcast_ref::<StructArray>().unwrap();
let expected = StructArray::from(vec![
(
- field_a,
+ field_a.clone(),
Arc::new(StringArray::from(vec!["First", "Second", "Third"]))
as ArrayRef,
),
(
- field_list,
+ field_primitive_list.clone(),
Arc::new(ListArray::from_iter_primitive::<Int32Type, _,
_>(vec![
Some(vec![Some(1), Some(2), Some(3)]),
Some(vec![Some(4), Some(5)]),
@@ -2316,5 +2400,250 @@ mod tests {
]);
assert_eq!(array, &expected);
+
+ // Define list-of-structs scalars
+ let nl0 = ScalarValue::List(
+ Some(Box::new(vec![s0.clone(), s1.clone()])),
+ Box::new(s0.get_datatype()),
+ );
+
+ let nl1 =
+ ScalarValue::List(Some(Box::new(vec![s2])),
Box::new(s0.get_datatype()));
+
+ let nl2 =
+ ScalarValue::List(Some(Box::new(vec![s1])),
Box::new(s0.get_datatype()));
+
+ // iter_to_array for list-of-struct
+ let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();
+ let array = array.as_any().downcast_ref::<ListArray>().unwrap();
+
+ // Construct expected array with array builders
+ let field_a_builder = StringBuilder::new(4);
+ let primitive_value_builder = Int32Array::builder(8);
+ let field_primitive_list_builder =
ListBuilder::new(primitive_value_builder);
+
+ let element_builder = StructBuilder::new(
+ vec![field_a, field_primitive_list],
+ vec![
+ Box::new(field_a_builder),
+ Box::new(field_primitive_list_builder),
+ ],
+ );
+ let mut list_builder = ListBuilder::new(element_builder);
+
+ list_builder
+ .values()
+ .field_builder::<StringBuilder>(0)
+ .unwrap()
+ .append_value("First")
+ .unwrap();
+ list_builder
+ .values()
+ .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+ .unwrap()
+ .values()
+ .append_value(1)
+ .unwrap();
+ list_builder
+ .values()
+ .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+ .unwrap()
+ .values()
+ .append_value(2)
+ .unwrap();
+ list_builder
+ .values()
+ .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+ .unwrap()
+ .values()
+ .append_value(3)
+ .unwrap();
+ list_builder
+ .values()
+ .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+ .unwrap()
+ .append(true)
+ .unwrap();
+ list_builder.values().append(true).unwrap();
+
+ list_builder
+ .values()
+ .field_builder::<StringBuilder>(0)
+ .unwrap()
+ .append_value("Second")
+ .unwrap();
+ list_builder
+ .values()
+ .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+ .unwrap()
+ .values()
+ .append_value(4)
+ .unwrap();
+ list_builder
+ .values()
+ .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+ .unwrap()
+ .values()
+ .append_value(5)
+ .unwrap();
+ list_builder
+ .values()
+ .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+ .unwrap()
+ .append(true)
+ .unwrap();
+ list_builder.values().append(true).unwrap();
+ list_builder.append(true).unwrap();
+
+ list_builder
+ .values()
+ .field_builder::<StringBuilder>(0)
+ .unwrap()
+ .append_value("Third")
+ .unwrap();
+ list_builder
+ .values()
+ .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+ .unwrap()
+ .values()
+ .append_value(6)
+ .unwrap();
+ list_builder
+ .values()
+ .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+ .unwrap()
+ .append(true)
+ .unwrap();
+ list_builder.values().append(true).unwrap();
+ list_builder.append(true).unwrap();
+
+ list_builder
+ .values()
+ .field_builder::<StringBuilder>(0)
+ .unwrap()
+ .append_value("Second")
+ .unwrap();
+ list_builder
+ .values()
+ .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+ .unwrap()
+ .values()
+ .append_value(4)
+ .unwrap();
+ list_builder
+ .values()
+ .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+ .unwrap()
+ .values()
+ .append_value(5)
+ .unwrap();
+ list_builder
+ .values()
+ .field_builder::<ListBuilder<PrimitiveBuilder<Int32Type>>>(1)
+ .unwrap()
+ .append(true)
+ .unwrap();
+ list_builder.values().append(true).unwrap();
+ list_builder.append(true).unwrap();
+
+ let expected = list_builder.finish();
+
+ assert_eq!(array, &expected);
+ }
+
+    #[test]
+    fn test_nested_lists() {
+        // Wraps a run of i32 values as a list scalar with Int32 elements
+        let int_list = |items: Vec<i32>| {
+            ScalarValue::List(
+                Some(Box::new(
+                    items.into_iter().map(ScalarValue::from).collect::<Vec<_>>(),
+                )),
+                Box::new(DataType::Int32),
+            )
+        };
+
+        // Wraps already-built list scalars as a list-of-lists scalar
+        let list_of_lists = |inner_lists: Vec<ScalarValue>| {
+            ScalarValue::List(
+                Some(Box::new(inner_lists)),
+                Box::new(DataType::List(Box::new(Field::new(
+                    "item",
+                    DataType::Int32,
+                    true,
+                )))),
+            )
+        };
+
+        // One entry per outer row; each row is a list of integer lists
+        let rows: Vec<Vec<Vec<i32>>> = vec![
+            vec![vec![1, 2, 3], vec![4, 5]],
+            vec![vec![6], vec![7, 8]],
+            vec![vec![9]],
+        ];
+
+        // Define nested list scalars from the row data
+        let scalars: Vec<ScalarValue> = rows
+            .iter()
+            .map(|row| {
+                list_of_lists(
+                    row.iter()
+                        .map(|inner| int_list(inner.clone()))
+                        .collect::<Vec<_>>(),
+                )
+            })
+            .collect();
+
+        let array = ScalarValue::iter_to_array(scalars).unwrap();
+        let array = array.as_any().downcast_ref::<ListArray>().unwrap();
+
+        // Construct expected array with array builders, walking the same rows
+        let mut outer_builder =
+            ListBuilder::new(ListBuilder::new(Int32Array::builder(8)));
+        for row in &rows {
+            for inner in row {
+                for value in inner {
+                    outer_builder
+                        .values()
+                        .values()
+                        .append_value(*value)
+                        .unwrap();
+                }
+                // Close the inner list
+                outer_builder.values().append(true).unwrap();
+            }
+            // Close the outer row
+            outer_builder.append(true).unwrap();
+        }
+
+        let expected = outer_builder.finish();
+
+        assert_eq!(array, &expected);
     }
}