This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a7dbbb  Improve GetIndexedFieldExpr adding utf8 key based access for 
struct v… (#1204)
6a7dbbb is described below

commit 6a7dbbb848880e4bd8a013b85aa38780819622ec
Author: Guillaume Balaine <[email protected]>
AuthorDate: Tue Nov 2 21:08:38 2021 +0100

    Improve GetIndexedFieldExpr adding utf8 key based access for struct v… 
(#1204)
    
    * Improve GetIndexedFieldExpr adding utf8 key based access for struct values
    
    * fix clippies
---
 datafusion/src/field_util.rs                       |  21 ++-
 datafusion/src/logical_plan/expr.rs                |   2 +-
 .../physical_plan/expressions/get_indexed_field.rs | 167 +++++++++++++++++++--
 datafusion/tests/sql.rs                            |  44 ++++++
 4 files changed, 220 insertions(+), 14 deletions(-)

diff --git a/datafusion/src/field_util.rs b/datafusion/src/field_util.rs
index 9d5face..272c17b 100644
--- a/datafusion/src/field_util.rs
+++ b/datafusion/src/field_util.rs
@@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field};
 use crate::error::{DataFusionError, Result};
 use crate::scalar::ScalarValue;
 
-/// Returns the field access indexed by `key` from a [`DataType::List`]
+/// Returns the field access indexed by `key` from a [`DataType::List`] or 
[`DataType::Struct`]
 /// # Error
 /// Errors if
 /// * the `data_type` is not a Struct or,
@@ -39,6 +39,25 @@ pub fn get_indexed_field(data_type: &DataType, key: 
&ScalarValue) -> Result<Fiel
                 Ok(Field::new(&i.to_string(), lt.data_type().clone(), false))
             }
         }
+        (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
+            if s.is_empty() {
+                Err(DataFusionError::Plan(
+                    "Struct based indexed access requires a non empty 
string".to_string(),
+                ))
+            } else {
+                let field = fields.iter().find(|f| f.name() == s);
+                match field {
+                    None => Err(DataFusionError::Plan(format!(
+                        "Field {} not found in struct",
+                        s
+                    ))),
+                    Some(f) => Ok(f.clone()),
+                }
+            }
+        }
+        (DataType::Struct(_), _) => Err(DataFusionError::Plan(
+            "Only utf8 strings are valid as an indexed field in a 
struct".to_string(),
+        )),
         (DataType::List(_), _) => Err(DataFusionError::Plan(
             "Only ints are valid as an indexed field in a list".to_string(),
         )),
diff --git a/datafusion/src/logical_plan/expr.rs 
b/datafusion/src/logical_plan/expr.rs
index 19e6fe3..318d73f 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -246,7 +246,7 @@ pub enum Expr {
     IsNull(Box<Expr>),
     /// arithmetic negation of an expression, the operand must be of a signed 
numeric data type
     Negative(Box<Expr>),
-    /// Returns the field of a [`ListArray`] by key
+    /// Returns the field of a [`ListArray`] or [`StructArray`] by key
     GetIndexedField {
         /// the expression to take the field from
         expr: Box<Expr>,
diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs 
b/datafusion/src/physical_plan/expressions/get_indexed_field.rs
index 8a9191e..7e60698 100644
--- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs
+++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs
@@ -34,7 +34,7 @@ use crate::{
     field_util::get_indexed_field as get_data_type_field,
     physical_plan::{ColumnarValue, PhysicalExpr},
 };
-use arrow::array::ListArray;
+use arrow::array::{ListArray, StructArray};
 use std::fmt::Debug;
 
 /// expression to get a field of a struct array.
@@ -81,7 +81,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
         let arg = self.arg.evaluate(batch)?;
         match arg {
             ColumnarValue::Array(array) => match (array.data_type(), 
&self.key) {
-                (DataType::List(_), _) if self.key.is_null() => {
+                (DataType::List(_) | DataType::Struct(_), _) if 
self.key.is_null() => {
                     let scalar_null: ScalarValue = 
array.data_type().try_into()?;
                     Ok(ColumnarValue::Scalar(scalar_null))
                 }
@@ -100,6 +100,13 @@ impl PhysicalExpr for GetIndexedFieldExpr {
                     let iter = concat(vec.as_slice()).unwrap();
                     Ok(ColumnarValue::Array(iter))
                 }
+                (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
+                    let as_struct_array = 
array.as_any().downcast_ref::<StructArray>().unwrap();
+                    match as_struct_array.column_by_name(k) {
+                        None => Err(DataFusionError::Execution(format!("get 
indexed field {} not found in struct", k))),
+                        Some(col) => Ok(ColumnarValue::Array(col.clone()))
+                    }
+                }
                 (dt, key) => Err(DataFusionError::NotImplemented(format!("get 
indexed field is only possible on lists with int64 indexes. Tried {} with {} 
index", dt, key))),
             },
             ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented(
@@ -112,18 +119,16 @@ impl PhysicalExpr for GetIndexedFieldExpr {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::arrow::array::GenericListArray;
     use crate::error::Result;
     use crate::physical_plan::expressions::{col, lit};
-    use arrow::array::{ListBuilder, StringBuilder};
+    use arrow::array::{
+        Int64Array, Int64Builder, ListBuilder, StringBuilder, StructArray, 
StructBuilder,
+    };
     use arrow::{array::StringArray, datatypes::Field};
 
-    fn get_indexed_field_test(
-        list_of_lists: Vec<Vec<Option<&str>>>,
-        index: i64,
-        expected: Vec<Option<&str>>,
-    ) -> Result<()> {
-        let schema = list_schema("l");
-        let builder = StringBuilder::new(3);
+    fn build_utf8_lists(list_of_lists: Vec<Vec<Option<&str>>>) -> 
GenericListArray<i32> {
+        let builder = StringBuilder::new(list_of_lists.len());
         let mut lb = ListBuilder::new(builder);
         for values in list_of_lists {
             let builder = lb.values();
@@ -137,9 +142,18 @@ mod tests {
             lb.append(true).unwrap();
         }
 
-        let expr = col("l", &schema).unwrap();
-        let batch = RecordBatch::try_new(Arc::new(schema), 
vec![Arc::new(lb.finish())])?;
+        lb.finish()
+    }
 
+    fn get_indexed_field_test(
+        list_of_lists: Vec<Vec<Option<&str>>>,
+        index: i64,
+        expected: Vec<Option<&str>>,
+    ) -> Result<()> {
+        let schema = list_schema("l");
+        let list_col = build_utf8_lists(list_of_lists);
+        let expr = col("l", &schema).unwrap();
+        let batch = RecordBatch::try_new(Arc::new(schema), 
vec![Arc::new(list_col)])?;
         let key = ScalarValue::Int64(Some(index));
         let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
@@ -222,4 +236,133 @@ mod tests {
         let expr = col("l", &schema).unwrap();
         get_indexed_field_test_failure(schema, expr,  
ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field 
is only possible on lists with int64 indexes. Tried List(Field { name: 
\"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, 
metadata: None }) with 0 index")
     }
+
+    fn build_struct(
+        fields: Vec<Field>,
+        list_of_tuples: Vec<(Option<i64>, Vec<Option<&str>>)>,
+    ) -> StructArray {
+        let foo_builder = Int64Array::builder(list_of_tuples.len());
+        let str_builder = StringBuilder::new(list_of_tuples.len());
+        let bar_builder = ListBuilder::new(str_builder);
+        let mut builder = StructBuilder::new(
+            fields,
+            vec![Box::new(foo_builder), Box::new(bar_builder)],
+        );
+        for (int_value, list_value) in list_of_tuples {
+            let fb = builder.field_builder::<Int64Builder>(0).unwrap();
+            match int_value {
+                None => fb.append_null(),
+                Some(v) => fb.append_value(v),
+            }
+            .unwrap();
+            builder.append(true).unwrap();
+            let lb = builder
+                .field_builder::<ListBuilder<StringBuilder>>(1)
+                .unwrap();
+            for str_value in list_value {
+                match str_value {
+                    None => lb.values().append_null(),
+                    Some(v) => lb.values().append_value(v),
+                }
+                .unwrap();
+            }
+            lb.append(true).unwrap();
+        }
+        builder.finish()
+    }
+
+    fn get_indexed_field_mixed_test(
+        list_of_tuples: Vec<(Option<i64>, Vec<Option<&str>>)>,
+        expected_strings: Vec<Vec<Option<&str>>>,
+        expected_ints: Vec<Option<i64>>,
+    ) -> Result<()> {
+        let struct_col = "s";
+        let fields = vec![
+            Field::new("foo", DataType::Int64, true),
+            Field::new(
+                "bar",
+                DataType::List(Box::new(Field::new("item", DataType::Utf8, 
true))),
+                true,
+            ),
+        ];
+        let schema = Schema::new(vec![Field::new(
+            struct_col,
+            DataType::Struct(fields.clone()),
+            true,
+        )]);
+        let struct_col = build_struct(fields, list_of_tuples.clone());
+
+        let struct_col_expr = col("s", &schema).unwrap();
+        let batch = RecordBatch::try_new(Arc::new(schema), 
vec![Arc::new(struct_col)])?;
+
+        let int_field_key = ScalarValue::Utf8(Some("foo".to_string()));
+        let get_field_expr = Arc::new(GetIndexedFieldExpr::new(
+            struct_col_expr.clone(),
+            int_field_key,
+        ));
+        let result = get_field_expr
+            .evaluate(&batch)?
+            .into_array(batch.num_rows());
+        let result = result
+            .as_any()
+            .downcast_ref::<Int64Array>()
+            .expect("failed to downcast to Int64Array");
+        let expected = &Int64Array::from(expected_ints);
+        assert_eq!(expected, result);
+
+        let list_field_key = ScalarValue::Utf8(Some("bar".to_string()));
+        let get_list_expr =
+            Arc::new(GetIndexedFieldExpr::new(struct_col_expr, 
list_field_key));
+        let result = 
get_list_expr.evaluate(&batch)?.into_array(batch.num_rows());
+        let result = result
+            .as_any()
+            .downcast_ref::<ListArray>()
+            .unwrap_or_else(|| panic!("failed to downcast to ListArray : 
{:?}", result));
+        let expected =
+            &build_utf8_lists(list_of_tuples.into_iter().map(|t| 
t.1).collect());
+        assert_eq!(expected, result);
+
+        for (i, expected) in expected_strings.into_iter().enumerate() {
+            let get_nested_str_expr = Arc::new(GetIndexedFieldExpr::new(
+                get_list_expr.clone(),
+                ScalarValue::Int64(Some(i as i64)),
+            ));
+            let result = get_nested_str_expr
+                .evaluate(&batch)?
+                .into_array(batch.num_rows());
+            let result = result
+                .as_any()
+                .downcast_ref::<StringArray>()
+                .unwrap_or_else(|| {
+                    panic!("failed to downcast to StringArray : {:?}", result)
+                });
+            let expected = &StringArray::from(expected);
+            assert_eq!(expected, result);
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn get_indexed_field_struct() -> Result<()> {
+        let list_of_structs = vec![
+            (Some(10), vec![Some("a"), Some("b"), None]),
+            (Some(15), vec![None, Some("c"), Some("d")]),
+            (None, vec![Some("e"), None, Some("f")]),
+        ];
+
+        let expected_list = vec![
+            vec![Some("a"), None, Some("e")],
+            vec![Some("b"), Some("c"), None],
+            vec![None, Some("d"), Some("f")],
+        ];
+
+        let expected_ints = vec![Some(10), Some(15), None];
+
+        get_indexed_field_mixed_test(
+            list_of_structs.clone(),
+            expected_list,
+            expected_ints,
+        )?;
+        Ok(())
+    }
 }
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index dd9198c..6cd1d38 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -5476,3 +5476,47 @@ async fn query_nested_get_indexed_field() -> Result<()> {
     assert_eq!(expected, actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    let nested_dt = DataType::List(Box::new(Field::new("item", 
DataType::Int64, true)));
+    // Nested schema of { "some_struct": { "bar": [i64] } }
+    let struct_fields = vec![Field::new("bar", nested_dt.clone(), true)];
+    let schema = Arc::new(Schema::new(vec![Field::new(
+        "some_struct",
+        DataType::Struct(struct_fields.clone()),
+        false,
+    )]));
+
+    let builder = PrimitiveBuilder::<Int64Type>::new(3);
+    let nested_lb = ListBuilder::new(builder);
+    let mut sb = StructBuilder::new(struct_fields, vec![Box::new(nested_lb)]);
+    for int_vec in vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 
11]] {
+        let lb = sb.field_builder::<ListBuilder<Int64Builder>>(0).unwrap();
+        for int in int_vec {
+            lb.values().append_value(int).unwrap();
+        }
+        lb.append(true).unwrap();
+    }
+    let data = RecordBatch::try_new(schema.clone(), 
vec![Arc::new(sb.finish())])?;
+    let table = MemTable::try_new(schema, vec![vec![data]])?;
+    let table_a = Arc::new(table);
+
+    ctx.register_table("structs", table_a)?;
+
+    // Select the nested list field `bar` from the struct column by its name
+    let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3";
+    let actual = execute(&mut ctx, sql).await;
+    let expected = vec![
+        vec!["[0, 1, 2, 3]"],
+        vec!["[4, 5, 6, 7]"],
+        vec!["[8, 9, 10, 11]"],
+    ];
+    assert_eq!(expected, actual);
+    let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3";
+    let actual = execute(&mut ctx, sql).await;
+    let expected = vec![vec!["0"], vec!["4"], vec!["8"]];
+    assert_eq!(expected, actual);
+    Ok(())
+}

Reply via email to