This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new db4f9984e Fix nested list indexing when the index is 0 or larger than
the list size (#5311)
db4f9984e is described below
commit db4f9984eb18a495b1ce48765fdf651343e1b10f
Author: Ahmed Riza <[email protected]>
AuthorDate: Sun Feb 19 15:41:49 2023 +0000
Fix nested list indexing when the index is 0 or larger than the list size
(#5311)
* [Bugfix] When a nested list was indexed with an invalid index, such as
0 or an index larger than the size of the list, an error was thrown
instead of correctly returning a null of the list's element type.
This PR fixes that so that indexing behaves uniformly.
* Address review comments on #5311
---
.../src/expressions/get_indexed_field.rs | 104 +++++++++++++++++++--
1 file changed, 94 insertions(+), 10 deletions(-)
diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs
b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
index 9665a8272..b9edc6029 100644
--- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs
+++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
@@ -89,11 +89,11 @@ impl PhysicalExpr for GetIndexedFieldExpr {
let scalar_null: ScalarValue = array.data_type().try_into()?;
Ok(ColumnarValue::Scalar(scalar_null))
}
- (DataType::List(_), ScalarValue::Int64(Some(i))) => {
+ (DataType::List(lst), ScalarValue::Int64(Some(i))) => {
let as_list_array = as_list_array(&array)?;
if *i < 1 || as_list_array.is_empty() {
- let scalar_null: ScalarValue =
array.data_type().try_into()?;
+ let scalar_null: ScalarValue = lst.data_type().try_into()?;
return Ok(ColumnarValue::Scalar(scalar_null))
}
@@ -111,7 +111,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
// concat requires input of at least one array
if sliced_array.is_empty() {
- let scalar_null: ScalarValue =
array.data_type().try_into()?;
+ let scalar_null: ScalarValue = lst.data_type().try_into()?;
Ok(ColumnarValue::Scalar(scalar_null))
} else {
let vec = sliced_array.iter().map(|a|
a.as_ref()).collect::<Vec<&dyn Array>>();
@@ -123,13 +123,20 @@ impl PhysicalExpr for GetIndexedFieldExpr {
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
let as_struct_array = as_struct_array(&array)?;
match as_struct_array.column_by_name(k) {
- None => Err(DataFusionError::Execution(format!("get
indexed field {k} not found in struct"))),
+ None => Err(DataFusionError::Execution(
+ format!("get indexed field {k} not found in struct"))),
Some(col) => Ok(ColumnarValue::Array(col.clone()))
}
}
- (DataType::List(_), key) =>
Err(DataFusionError::Execution(format!("get indexed field is only possible on
lists with int64 indexes. Tried with {key:?} index"))),
- (DataType::Struct(_), key) =>
Err(DataFusionError::Execution(format!("get indexed field is only possible on
struct with utf8 indexes. Tried with {key:?} index"))),
- (dt, key) => Err(DataFusionError::Execution(format!("get indexed
field is only possible on lists with int64 indexes or struct with utf8 indexes.
Tried {dt:?} with {key:?} index"))),
+ (DataType::List(_), key) => Err(DataFusionError::Execution(
+ format!("get indexed field is only possible on lists with
int64 indexes. \
+ Tried with {key:?} index"))),
+ (DataType::Struct(_), key) => Err(DataFusionError::Execution(
+ format!("get indexed field is only possible on struct with
utf8 indexes. \
+ Tried with {key:?} index"))),
+ (dt, key) => Err(DataFusionError::Execution(
+ format!("get indexed field is only possible on lists with
int64 indexes or struct \
+ with utf8 indexes. Tried {dt:?} with {key:?}
index"))),
}
}
@@ -161,10 +168,11 @@ impl PartialEq<dyn Any> for GetIndexedFieldExpr {
mod tests {
use super::*;
use crate::expressions::{col, lit};
- use arrow::array::GenericListArray;
+ use arrow::array::{ArrayRef, Float64Array, GenericListArray,
PrimitiveBuilder};
use arrow::array::{
Int64Array, Int64Builder, ListBuilder, StringBuilder, StructArray,
StructBuilder,
};
+ use arrow::datatypes::{Float64Type, Int64Type};
use arrow::{array::StringArray, datatypes::Field};
use datafusion_common::cast::{as_int64_array, as_string_array};
use datafusion_common::Result;
@@ -265,14 +273,20 @@ mod tests {
fn get_indexed_field_invalid_scalar() -> Result<()> {
let schema = list_schema("l");
let expr = lit("a");
- get_indexed_field_test_failure(schema, expr,
ScalarValue::Int64(Some(0)), "Execution error: get indexed field is only
possible on lists with int64 indexes or struct with utf8 indexes. Tried Utf8
with Int64(0) index")
+ get_indexed_field_test_failure(
+ schema, expr, ScalarValue::Int64(Some(0)),
+ "Execution error: get indexed field is only possible on lists with
int64 indexes or \
+ struct with utf8 indexes. Tried Utf8 with Int64(0) index")
}
#[test]
fn get_indexed_field_invalid_list_index() -> Result<()> {
let schema = list_schema("l");
let expr = col("l", &schema).unwrap();
- get_indexed_field_test_failure(schema, expr,
ScalarValue::Int8(Some(0)), "Execution error: get indexed field is only
possible on lists with int64 indexes. Tried with Int8(0) index")
+ get_indexed_field_test_failure(
+ schema, expr, ScalarValue::Int8(Some(0)),
+ "Execution error: get indexed field is only possible on lists with
int64 indexes. \
+ Tried with Int8(0) index")
}
fn build_struct(
@@ -390,4 +404,74 @@ mod tests {
)?;
Ok(())
}
+
+    #[test]
+    fn get_indexed_field_list_out_of_bounds() {
+        // A single-row batch whose "a" column holds the list [1.0, NULL, 3.0].
+        let schema = Schema::new(vec![
+            Field::new("id", DataType::Int64, true),
+            Field::new(
+                "a",
+                DataType::List(Box::new(Field::new("item", DataType::Float64, true))),
+                true,
+            ),
+        ]);
+
+        let mut ids = PrimitiveBuilder::<Int64Type>::new();
+        ids.append_value(1);
+
+        let mut lists = ListBuilder::new(PrimitiveBuilder::<Float64Type>::new());
+        lists.values().append_value(1.0);
+        lists.values().append_null();
+        lists.values().append_value(3.0);
+        lists.append(true);
+
+        let batch = RecordBatch::try_new(
+            Arc::new(schema.clone()),
+            vec![Arc::new(ids.finish()), Arc::new(lists.finish())],
+        )
+        .unwrap();
+
+        let list_col = col("a", &schema).unwrap();
+        // List indexes are 1-based: 0 is invalid and 100 is past the end, so both
+        // are expected to produce a single Float64 null. Index 2 hits the stored null.
+        let cases = [
+            (0, None),
+            (1, Some(1.0)),
+            (2, None),
+            (3, Some(3.0)),
+            (100, None),
+        ];
+        for &(index, expected) in &cases {
+            verify_index_evaluation(
+                &batch,
+                list_col.clone(),
+                index,
+                float64_array(expected),
+            );
+        }
+    }
+
+    /// Evaluates `arg[index]` (a `GetIndexedFieldExpr` with an `Int64` key) against
+    /// `batch` and asserts that the result equals `expected_result`, data type included.
+    ///
+    /// # Panics
+    /// Panics if evaluation fails, or if the result's data type or values differ
+    /// from `expected_result`.
+    fn verify_index_evaluation(
+        batch: &RecordBatch,
+        arg: Arc<dyn PhysicalExpr>,
+        index: i64,
+        expected_result: ArrayRef,
+    ) {
+        let expr = GetIndexedFieldExpr::new(arg, ScalarValue::Int64(Some(index)));
+        let result = expr.evaluate(batch).unwrap().into_array(batch.num_rows());
+        // Compare types first: a null built from the list type instead of the element
+        // type is the regression this helper guards against, and a type mismatch
+        // reports more clearly than a whole-array comparison.
+        assert_eq!(result.data_type(), expected_result.data_type());
+        assert_eq!(result, expected_result);
+    }
+
+    /// Builds a one-element `Float64` array containing `value`, or a single null
+    /// entry when `value` is `None`.
+    fn float64_array(value: Option<f64>) -> ArrayRef {
+        if let Some(v) = value {
+            return Arc::new(Float64Array::from_value(v, 1));
+        }
+        let mut null_builder = PrimitiveBuilder::<Float64Type>::new();
+        null_builder.append_null();
+        Arc::new(null_builder.finish())
+    }
}