This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 6a7dbbb Improve GetIndexedFieldExpr adding utf8 key based access for
struct values (#1204)
6a7dbbb is described below
commit 6a7dbbb848880e4bd8a013b85aa38780819622ec
Author: Guillaume Balaine <[email protected]>
AuthorDate: Tue Nov 2 21:08:38 2021 +0100
Improve GetIndexedFieldExpr adding utf8 key based access for struct values
(#1204)
* Improve GetIndexedFieldExpr adding utf8 key based access for struct values
* fix clippies
---
datafusion/src/field_util.rs | 21 ++-
datafusion/src/logical_plan/expr.rs | 2 +-
.../physical_plan/expressions/get_indexed_field.rs | 167 +++++++++++++++++++--
datafusion/tests/sql.rs | 44 ++++++
4 files changed, 220 insertions(+), 14 deletions(-)
diff --git a/datafusion/src/field_util.rs b/datafusion/src/field_util.rs
index 9d5face..272c17b 100644
--- a/datafusion/src/field_util.rs
+++ b/datafusion/src/field_util.rs
@@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field};
use crate::error::{DataFusionError, Result};
use crate::scalar::ScalarValue;
-/// Returns the field access indexed by `key` from a [`DataType::List`]
+/// Returns the field access indexed by `key` from a [`DataType::List`] or
[`DataType::Struct`]
/// # Error
/// Errors if
/// * the `data_type` is not a Struct or,
@@ -39,6 +39,25 @@ pub fn get_indexed_field(data_type: &DataType, key:
&ScalarValue) -> Result<Fiel
Ok(Field::new(&i.to_string(), lt.data_type().clone(), false))
}
}
+ (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
+ if s.is_empty() {
+ Err(DataFusionError::Plan(
+ "Struct based indexed access requires a non empty
string".to_string(),
+ ))
+ } else {
+ let field = fields.iter().find(|f| f.name() == s);
+ match field {
+ None => Err(DataFusionError::Plan(format!(
+ "Field {} not found in struct",
+ s
+ ))),
+ Some(f) => Ok(f.clone()),
+ }
+ }
+ }
+ (DataType::Struct(_), _) => Err(DataFusionError::Plan(
+ "Only utf8 strings are valid as an indexed field in a
struct".to_string(),
+ )),
(DataType::List(_), _) => Err(DataFusionError::Plan(
"Only ints are valid as an indexed field in a list".to_string(),
)),
diff --git a/datafusion/src/logical_plan/expr.rs
b/datafusion/src/logical_plan/expr.rs
index 19e6fe3..318d73f 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -246,7 +246,7 @@ pub enum Expr {
IsNull(Box<Expr>),
/// arithmetic negation of an expression, the operand must be of a signed
numeric data type
Negative(Box<Expr>),
- /// Returns the field of a [`ListArray`] by key
+ /// Returns the field of a [`ListArray`] or [`StructArray`] by key
GetIndexedField {
/// the expression to take the field from
expr: Box<Expr>,
diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs
b/datafusion/src/physical_plan/expressions/get_indexed_field.rs
index 8a9191e..7e60698 100644
--- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs
+++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs
@@ -34,7 +34,7 @@ use crate::{
field_util::get_indexed_field as get_data_type_field,
physical_plan::{ColumnarValue, PhysicalExpr},
};
-use arrow::array::ListArray;
+use arrow::array::{ListArray, StructArray};
use std::fmt::Debug;
/// expression to get a field of a struct array.
@@ -81,7 +81,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
let arg = self.arg.evaluate(batch)?;
match arg {
ColumnarValue::Array(array) => match (array.data_type(),
&self.key) {
- (DataType::List(_), _) if self.key.is_null() => {
+ (DataType::List(_) | DataType::Struct(_), _) if
self.key.is_null() => {
let scalar_null: ScalarValue =
array.data_type().try_into()?;
Ok(ColumnarValue::Scalar(scalar_null))
}
@@ -100,6 +100,13 @@ impl PhysicalExpr for GetIndexedFieldExpr {
let iter = concat(vec.as_slice()).unwrap();
Ok(ColumnarValue::Array(iter))
}
+ (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
+ let as_struct_array =
array.as_any().downcast_ref::<StructArray>().unwrap();
+ match as_struct_array.column_by_name(k) {
+ None => Err(DataFusionError::Execution(format!("get
indexed field {} not found in struct", k))),
+ Some(col) => Ok(ColumnarValue::Array(col.clone()))
+ }
+ }
(dt, key) => Err(DataFusionError::NotImplemented(format!("get
indexed field is only possible on lists with int64 indexes. Tried {} with {}
index", dt, key))),
},
ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented(
@@ -112,18 +119,16 @@ impl PhysicalExpr for GetIndexedFieldExpr {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::arrow::array::GenericListArray;
use crate::error::Result;
use crate::physical_plan::expressions::{col, lit};
- use arrow::array::{ListBuilder, StringBuilder};
+ use arrow::array::{
+ Int64Array, Int64Builder, ListBuilder, StringBuilder, StructArray,
StructBuilder,
+ };
use arrow::{array::StringArray, datatypes::Field};
- fn get_indexed_field_test(
- list_of_lists: Vec<Vec<Option<&str>>>,
- index: i64,
- expected: Vec<Option<&str>>,
- ) -> Result<()> {
- let schema = list_schema("l");
- let builder = StringBuilder::new(3);
+ fn build_utf8_lists(list_of_lists: Vec<Vec<Option<&str>>>) ->
GenericListArray<i32> {
+ let builder = StringBuilder::new(list_of_lists.len());
let mut lb = ListBuilder::new(builder);
for values in list_of_lists {
let builder = lb.values();
@@ -137,9 +142,18 @@ mod tests {
lb.append(true).unwrap();
}
- let expr = col("l", &schema).unwrap();
- let batch = RecordBatch::try_new(Arc::new(schema),
vec![Arc::new(lb.finish())])?;
+ lb.finish()
+ }
+ fn get_indexed_field_test(
+ list_of_lists: Vec<Vec<Option<&str>>>,
+ index: i64,
+ expected: Vec<Option<&str>>,
+ ) -> Result<()> {
+ let schema = list_schema("l");
+ let list_col = build_utf8_lists(list_of_lists);
+ let expr = col("l", &schema).unwrap();
+ let batch = RecordBatch::try_new(Arc::new(schema),
vec![Arc::new(list_col)])?;
let key = ScalarValue::Int64(Some(index));
let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
@@ -222,4 +236,133 @@ mod tests {
let expr = col("l", &schema).unwrap();
get_indexed_field_test_failure(schema, expr,
ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field
is only possible on lists with int64 indexes. Tried List(Field { name:
\"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false,
metadata: None }) with 0 index")
}
+
+ fn build_struct(
+ fields: Vec<Field>,
+ list_of_tuples: Vec<(Option<i64>, Vec<Option<&str>>)>,
+ ) -> StructArray {
+ let foo_builder = Int64Array::builder(list_of_tuples.len());
+ let str_builder = StringBuilder::new(list_of_tuples.len());
+ let bar_builder = ListBuilder::new(str_builder);
+ let mut builder = StructBuilder::new(
+ fields,
+ vec![Box::new(foo_builder), Box::new(bar_builder)],
+ );
+ for (int_value, list_value) in list_of_tuples {
+ let fb = builder.field_builder::<Int64Builder>(0).unwrap();
+ match int_value {
+ None => fb.append_null(),
+ Some(v) => fb.append_value(v),
+ }
+ .unwrap();
+ builder.append(true).unwrap();
+ let lb = builder
+ .field_builder::<ListBuilder<StringBuilder>>(1)
+ .unwrap();
+ for str_value in list_value {
+ match str_value {
+ None => lb.values().append_null(),
+ Some(v) => lb.values().append_value(v),
+ }
+ .unwrap();
+ }
+ lb.append(true).unwrap();
+ }
+ builder.finish()
+ }
+
+ fn get_indexed_field_mixed_test(
+ list_of_tuples: Vec<(Option<i64>, Vec<Option<&str>>)>,
+ expected_strings: Vec<Vec<Option<&str>>>,
+ expected_ints: Vec<Option<i64>>,
+ ) -> Result<()> {
+ let struct_col = "s";
+ let fields = vec![
+ Field::new("foo", DataType::Int64, true),
+ Field::new(
+ "bar",
+ DataType::List(Box::new(Field::new("item", DataType::Utf8,
true))),
+ true,
+ ),
+ ];
+ let schema = Schema::new(vec![Field::new(
+ struct_col,
+ DataType::Struct(fields.clone()),
+ true,
+ )]);
+ let struct_col = build_struct(fields, list_of_tuples.clone());
+
+ let struct_col_expr = col("s", &schema).unwrap();
+ let batch = RecordBatch::try_new(Arc::new(schema),
vec![Arc::new(struct_col)])?;
+
+ let int_field_key = ScalarValue::Utf8(Some("foo".to_string()));
+ let get_field_expr = Arc::new(GetIndexedFieldExpr::new(
+ struct_col_expr.clone(),
+ int_field_key,
+ ));
+ let result = get_field_expr
+ .evaluate(&batch)?
+ .into_array(batch.num_rows());
+ let result = result
+ .as_any()
+ .downcast_ref::<Int64Array>()
+ .expect("failed to downcast to Int64Array");
+ let expected = &Int64Array::from(expected_ints);
+ assert_eq!(expected, result);
+
+ let list_field_key = ScalarValue::Utf8(Some("bar".to_string()));
+ let get_list_expr =
+ Arc::new(GetIndexedFieldExpr::new(struct_col_expr,
list_field_key));
+ let result =
get_list_expr.evaluate(&batch)?.into_array(batch.num_rows());
+ let result = result
+ .as_any()
+ .downcast_ref::<ListArray>()
+ .unwrap_or_else(|| panic!("failed to downcast to ListArray :
{:?}", result));
+ let expected =
+ &build_utf8_lists(list_of_tuples.into_iter().map(|t|
t.1).collect());
+ assert_eq!(expected, result);
+
+ for (i, expected) in expected_strings.into_iter().enumerate() {
+ let get_nested_str_expr = Arc::new(GetIndexedFieldExpr::new(
+ get_list_expr.clone(),
+ ScalarValue::Int64(Some(i as i64)),
+ ));
+ let result = get_nested_str_expr
+ .evaluate(&batch)?
+ .into_array(batch.num_rows());
+ let result = result
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .unwrap_or_else(|| {
+ panic!("failed to downcast to StringArray : {:?}", result)
+ });
+ let expected = &StringArray::from(expected);
+ assert_eq!(expected, result);
+ }
+ Ok(())
+ }
+
+ #[test]
+ fn get_indexed_field_struct() -> Result<()> {
+ let list_of_structs = vec![
+ (Some(10), vec![Some("a"), Some("b"), None]),
+ (Some(15), vec![None, Some("c"), Some("d")]),
+ (None, vec![Some("e"), None, Some("f")]),
+ ];
+
+ let expected_list = vec![
+ vec![Some("a"), None, Some("e")],
+ vec![Some("b"), Some("c"), None],
+ vec![None, Some("d"), Some("f")],
+ ];
+
+ let expected_ints = vec![Some(10), Some(15), None];
+
+ get_indexed_field_mixed_test(
+ list_of_structs.clone(),
+ expected_list,
+ expected_ints,
+ )?;
+ Ok(())
+ }
}
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index dd9198c..6cd1d38 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -5476,3 +5476,47 @@ async fn query_nested_get_indexed_field() -> Result<()> {
assert_eq!(expected, actual);
Ok(())
}
+
+#[tokio::test]
+async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ let nested_dt = DataType::List(Box::new(Field::new("item",
DataType::Int64, true)));
+ // Nested schema of { "some_struct": { "bar": [i64] } }
+ let struct_fields = vec![Field::new("bar", nested_dt.clone(), true)];
+ let schema = Arc::new(Schema::new(vec![Field::new(
+ "some_struct",
+ DataType::Struct(struct_fields.clone()),
+ false,
+ )]));
+
+ let builder = PrimitiveBuilder::<Int64Type>::new(3);
+ let nested_lb = ListBuilder::new(builder);
+ let mut sb = StructBuilder::new(struct_fields, vec![Box::new(nested_lb)]);
+ for int_vec in vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10,
11]] {
+ let lb = sb.field_builder::<ListBuilder<Int64Builder>>(0).unwrap();
+ for int in int_vec {
+ lb.values().append_value(int).unwrap();
+ }
+ lb.append(true).unwrap();
+ }
+ let data = RecordBatch::try_new(schema.clone(),
vec![Arc::new(sb.finish())])?;
+ let table = MemTable::try_new(schema, vec![vec![data]])?;
+ let table_a = Arc::new(table);
+
+ ctx.register_table("structs", table_a)?;
+
+ // Access the nested list through its struct field key, then index into that list
+ let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![
+ vec!["[0, 1, 2, 3]"],
+ vec!["[4, 5, 6, 7]"],
+ vec!["[8, 9, 10, 11]"],
+ ];
+ assert_eq!(expected, actual);
+ let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["0"], vec!["4"], vec!["8"]];
+ assert_eq!(expected, actual);
+ Ok(())
+}