This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new db4f9984e Fix nested list indexing when the index is 0 or larger than 
the list size (#5311)
db4f9984e is described below

commit db4f9984eb18a495b1ce48765fdf651343e1b10f
Author: Ahmed Riza <[email protected]>
AuthorDate: Sun Feb 19 15:41:49 2023 +0000

    Fix nested list indexing when the index is 0 or larger than the list size 
(#5311)
    
    * [Bugfix] When indexing a nested list with an invalid index, such as
    0 or an index larger than the size of the list, an error was thrown
    instead of correctly returning a null.
    
    This PR fixes that so that indexing works uniformly.
    
    * Address review comments on #5311
---
 .../src/expressions/get_indexed_field.rs           | 104 +++++++++++++++++++--
 1 file changed, 94 insertions(+), 10 deletions(-)

diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs 
b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
index 9665a8272..b9edc6029 100644
--- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs
+++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
@@ -89,11 +89,11 @@ impl PhysicalExpr for GetIndexedFieldExpr {
                 let scalar_null: ScalarValue = array.data_type().try_into()?;
                 Ok(ColumnarValue::Scalar(scalar_null))
             }
-            (DataType::List(_), ScalarValue::Int64(Some(i))) => {
+            (DataType::List(lst), ScalarValue::Int64(Some(i))) => {
                 let as_list_array = as_list_array(&array)?;
 
                 if *i < 1 || as_list_array.is_empty() {
-                    let scalar_null: ScalarValue = 
array.data_type().try_into()?;
+                    let scalar_null: ScalarValue = lst.data_type().try_into()?;
                     return Ok(ColumnarValue::Scalar(scalar_null))
                 }
 
@@ -111,7 +111,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
 
                 // concat requires input of at least one array
                 if sliced_array.is_empty() {
-                    let scalar_null: ScalarValue = 
array.data_type().try_into()?;
+                    let scalar_null: ScalarValue = lst.data_type().try_into()?;
                     Ok(ColumnarValue::Scalar(scalar_null))
                 } else {
                     let vec = sliced_array.iter().map(|a| 
a.as_ref()).collect::<Vec<&dyn Array>>();
@@ -123,13 +123,20 @@ impl PhysicalExpr for GetIndexedFieldExpr {
             (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
                 let as_struct_array = as_struct_array(&array)?;
                 match as_struct_array.column_by_name(k) {
-                    None => Err(DataFusionError::Execution(format!("get 
indexed field {k} not found in struct"))),
+                    None => Err(DataFusionError::Execution(
+                        format!("get indexed field {k} not found in struct"))),
                     Some(col) => Ok(ColumnarValue::Array(col.clone()))
                 }
             }
-            (DataType::List(_), key) => 
Err(DataFusionError::Execution(format!("get indexed field is only possible on 
lists with int64 indexes. Tried with {key:?} index"))),
-            (DataType::Struct(_), key) => 
Err(DataFusionError::Execution(format!("get indexed field is only possible on 
struct with utf8 indexes. Tried with {key:?} index"))),
-            (dt, key) => Err(DataFusionError::Execution(format!("get indexed 
field is only possible on lists with int64 indexes or struct with utf8 indexes. 
Tried {dt:?} with {key:?} index"))),
+            (DataType::List(_), key) => Err(DataFusionError::Execution(
+                format!("get indexed field is only possible on lists with 
int64 indexes. \
+                         Tried with {key:?} index"))),
+            (DataType::Struct(_), key) => Err(DataFusionError::Execution(
+                format!("get indexed field is only possible on struct with 
utf8 indexes. \
+                         Tried with {key:?} index"))),
+            (dt, key) => Err(DataFusionError::Execution(
+                format!("get indexed field is only possible on lists with 
int64 indexes or struct \
+                         with utf8 indexes. Tried {dt:?} with {key:?} 
index"))),
         }
     }
 
@@ -161,10 +168,11 @@ impl PartialEq<dyn Any> for GetIndexedFieldExpr {
 mod tests {
     use super::*;
     use crate::expressions::{col, lit};
-    use arrow::array::GenericListArray;
+    use arrow::array::{ArrayRef, Float64Array, GenericListArray, 
PrimitiveBuilder};
     use arrow::array::{
         Int64Array, Int64Builder, ListBuilder, StringBuilder, StructArray, 
StructBuilder,
     };
+    use arrow::datatypes::{Float64Type, Int64Type};
     use arrow::{array::StringArray, datatypes::Field};
     use datafusion_common::cast::{as_int64_array, as_string_array};
     use datafusion_common::Result;
@@ -265,14 +273,20 @@ mod tests {
     fn get_indexed_field_invalid_scalar() -> Result<()> {
         let schema = list_schema("l");
         let expr = lit("a");
-        get_indexed_field_test_failure(schema, expr,  
ScalarValue::Int64(Some(0)), "Execution error: get indexed field is only 
possible on lists with int64 indexes or struct with utf8 indexes. Tried Utf8 
with Int64(0) index")
+        get_indexed_field_test_failure(
+            schema, expr,  ScalarValue::Int64(Some(0)),
+            "Execution error: get indexed field is only possible on lists with 
int64 indexes or \
+             struct with utf8 indexes. Tried Utf8 with Int64(0) index")
     }
 
     #[test]
     fn get_indexed_field_invalid_list_index() -> Result<()> {
         let schema = list_schema("l");
         let expr = col("l", &schema).unwrap();
-        get_indexed_field_test_failure(schema, expr,  
ScalarValue::Int8(Some(0)), "Execution error: get indexed field is only 
possible on lists with int64 indexes. Tried with Int8(0) index")
+        get_indexed_field_test_failure(
+            schema, expr,  ScalarValue::Int8(Some(0)),
+            "Execution error: get indexed field is only possible on lists with 
int64 indexes. \
+             Tried with Int8(0) index")
     }
 
     fn build_struct(
@@ -390,4 +404,74 @@ mod tests {
         )?;
         Ok(())
     }
+
+    #[test]
+    fn get_indexed_field_list_out_of_bounds() {
+        let fields = vec![
+            Field::new("id", DataType::Int64, true),
+            Field::new(
+                "a",
+                DataType::List(Box::new(Field::new("item", DataType::Float64, 
true))),
+                true,
+            ),
+        ];
+
+        let schema = Schema::new(fields);
+        let mut int_builder = PrimitiveBuilder::<Int64Type>::new();
+        int_builder.append_value(1);
+
+        let mut lb = ListBuilder::new(PrimitiveBuilder::<Float64Type>::new());
+        lb.values().append_value(1.0);
+        lb.values().append_null();
+        lb.values().append_value(3.0);
+        lb.append(true);
+
+        let batch = RecordBatch::try_new(
+            Arc::new(schema.clone()),
+            vec![Arc::new(int_builder.finish()), Arc::new(lb.finish())],
+        )
+        .unwrap();
+
+        let col_a = col("a", &schema).unwrap();
+        // out of bounds index
+        verify_index_evaluation(&batch, col_a.clone(), 0, float64_array(None));
+
+        verify_index_evaluation(&batch, col_a.clone(), 1, 
float64_array(Some(1.0)));
+        verify_index_evaluation(&batch, col_a.clone(), 2, float64_array(None));
+        verify_index_evaluation(&batch, col_a.clone(), 3, 
float64_array(Some(3.0)));
+
+        // out of bounds index
+        verify_index_evaluation(&batch, col_a.clone(), 100, 
float64_array(None));
+    }
+
+    fn verify_index_evaluation(
+        batch: &RecordBatch,
+        arg: Arc<dyn PhysicalExpr>,
+        index: i64,
+        expected_result: ArrayRef,
+    ) {
+        let expr = Arc::new(GetIndexedFieldExpr::new(
+            arg,
+            ScalarValue::Int64(Some(index)),
+        ));
+        let result = 
expr.evaluate(batch).unwrap().into_array(batch.num_rows());
+        assert!(
+            result == expected_result.clone(),
+            "result: {:?} != expected result: {:?}",
+            result,
+            expected_result
+        );
+        assert_eq!(result.data_type(), &DataType::Float64);
+    }
+
+    fn float64_array(value: Option<f64>) -> ArrayRef {
+        match value {
+            Some(v) => Arc::new(Float64Array::from_value(v, 1)),
+            None => {
+                let mut b = PrimitiveBuilder::<Float64Type>::new();
+                b.append_null();
+                Arc::new(b.finish())
+            }
+        }
+    }
 }

Reply via email to