This is an automated email from the ASF dual-hosted git repository. alamb pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push: new 310eab006b Map access supports constant-resolvable expressions (#14712) 310eab006b is described below commit 310eab006b26885ec558cf5c4572c73a7d824ee9 Author: Lordworms <48054792+lordwo...@users.noreply.github.com> AuthorDate: Thu Feb 20 04:00:24 2025 -0800 Map access supports constant-resolvable expressions (#14712) * Map access supports constant-resolvable expressions * adding tests fix clippy fix clippy fix clippy * fix clippy --- datafusion/functions-nested/src/planner.rs | 22 ++++-- datafusion/functions/src/core/getfield.rs | 111 ++++++++++++++++++----------- datafusion/sqllogictest/test_files/map.slt | 62 ++++++++++++++++ 3 files changed, 148 insertions(+), 47 deletions(-) diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index d55176a42c..369eaecb19 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -17,17 +17,20 @@ //! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`] -use std::sync::Arc; - +use arrow::datatypes::DataType; +use datafusion_common::ExprSchema; use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result}; -use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, ScalarFunction}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; use datafusion_expr::AggregateUDF; use datafusion_expr::{ planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, sqlparser, Expr, ExprSchemable, GetFieldAccess, }; +use datafusion_functions::core::get_field as get_field_inner; use datafusion_functions::expr_fn::get_field; use datafusion_functions_aggregate::nth_value::nth_value_udaf; +use std::sync::Arc; use crate::map::map_udf; use crate::{ @@ -140,7 +143,7 @@ impl ExprPlanner for FieldAccessPlanner { fn plan_field_access( &self, expr: RawFieldAccessExpr, - _schema: &DFSchema, + schema: &DFSchema, ) -> Result<PlannerResult<RawFieldAccessExpr>> { let RawFieldAccessExpr { expr, field_access } = expr; @@ -173,6 +176,17 @@ impl ExprPlanner for FieldAccessPlanner { null_treatment, )), )), + // special case for map access with + Expr::Column(ref c) + if matches!(schema.data_type(c)?, DataType::Map(_, _)) => + { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf( + get_field_inner(), + vec![expr, *index], + ), + ))) + } _ => Ok(PlannerResult::Planned(array_element(expr, *index))), } } diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index d667d0d8c1..d900ee5825 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -16,9 +16,12 @@ // under the License. use arrow::array::{ - make_array, Array, Capacities, MutableArrayData, Scalar, StringArray, + make_array, make_comparator, Array, BooleanArray, Capacities, MutableArrayData, + Scalar, }; +use arrow::compute::SortOptions; use arrow::datatypes::DataType; +use arrow_buffer::NullBuffer; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result, @@ -106,11 +109,7 @@ impl ScalarUDFImpl for GetFieldFunc { let name = match field_name { Expr::Literal(name) => name, - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } + other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), }; Ok(format!("{base}[{name}]")) @@ -118,14 +117,9 @@ impl ScalarUDFImpl for GetFieldFunc { fn schema_name(&self, args: &[Expr]) -> Result<String> { let [base, field_name] = take_function_args(self.name(), args)?; - let name = match field_name { Expr::Literal(name) => name, - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } + other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), }; Ok(format!("{}[{}]", base.schema_name(), name)) @@ -182,7 +176,6 @@ impl ScalarUDFImpl for GetFieldFunc { let arrays = ColumnarValue::values_to_arrays(&[base.clone(), field_name.clone()])?; let array = Arc::clone(&arrays[0]); - let name = match field_name { ColumnarValue::Scalar(name) => name, _ => { @@ -192,38 +185,70 @@ impl ScalarUDFImpl for GetFieldFunc { } }; + fn process_map_array( + array: Arc<dyn Array>, + key_array: Arc<dyn Array>, + ) -> Result<ColumnarValue> { + let map_array = as_map_array(array.as_ref())?; + let keys = if key_array.data_type().is_nested() { + let comparator = make_comparator( + map_array.keys().as_ref(), + key_array.as_ref(), + SortOptions::default(), + )?; + let len = map_array.keys().len().min(key_array.len()); + let values = (0..len).map(|i| comparator(i, i).is_eq()).collect(); + let nulls = + NullBuffer::union(map_array.keys().nulls(), key_array.nulls()); + BooleanArray::new(values, nulls) + } else { + let be_compared = Scalar::new(key_array); + arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())? + }; + + let original_data = map_array.entries().column(1).to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); + + for entry in 0..map_array.len() { + let start = map_array.value_offsets()[entry] as usize; + let end = map_array.value_offsets()[entry + 1] as usize; + + let maybe_matched = keys + .slice(start, end - start) + .iter() + .enumerate() + .find(|(_, t)| t.unwrap()); + + if maybe_matched.is_none() { + mutable.extend_nulls(1); + continue; + } + let (match_offset, _) = maybe_matched.unwrap(); + mutable.extend(0, start + match_offset, start + match_offset + 1); + } + + let data = mutable.freeze(); + let data = make_array(data); + Ok(ColumnarValue::Array(data)) + } + match (array.data_type(), name) { - (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { - let map_array = as_map_array(array.as_ref())?; - let key_scalar: Scalar<arrow::array::GenericByteArray<arrow::datatypes::GenericStringType<i32>>> = Scalar::new(StringArray::from(vec![k.clone()])); - let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; - - // note that this array has more entries than the expected output/input size - // because map_array is flattened - let original_data = map_array.entries().column(1).to_data(); - let capacity = Capacities::Array(original_data.len()); - let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], true, - capacity); - - for entry in 0..map_array.len(){ - let start = map_array.value_offsets()[entry] as usize; - let end = map_array.value_offsets()[entry + 1] as usize; - - let maybe_matched = - keys.slice(start, end-start). - iter().enumerate(). - find(|(_, t)| t.unwrap()); - if maybe_matched.is_none() { - mutable.extend_nulls(1); - continue - } - let (match_offset,_) = maybe_matched.unwrap(); - mutable.extend(0, start + match_offset, start + match_offset + 1); + (DataType::Map(_, _), ScalarValue::List(arr)) => { + let key_array: Arc<dyn Array> = Arc::new((**arr).clone()); + process_map_array(array, key_array) + } + (DataType::Map(_, _), ScalarValue::Struct(arr)) => { + process_map_array(array, Arc::clone(arr) as Arc<dyn Array>) + } + (DataType::Map(_, _), other) => { + let data_type = other.data_type(); + if data_type.is_nested() { + exec_err!("unsupported type {:?} for map access", data_type) + } else { + process_map_array(array, other.to_array()?) } - let data = mutable.freeze(); - let data = make_array(data); - Ok(ColumnarValue::Array(data)) } (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { let as_struct_array = as_struct_array(&array)?; diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 71296b6f64..42a4ba6218 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -592,6 +592,43 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) [NULL] [NULL] [[1, NULL, 3]] [NULL] [NULL] [NULL] +query ? +select column1[1] from map_array_table_1; +---- +[1, NULL, 3] +NULL +NULL +NULL + +query ? +select column1[-1000 + 1001] from map_array_table_1; +---- +[1, NULL, 3] +NULL +NULL +NULL + +# test for negative scenario +query ? +SELECT column1[-1] FROM map_array_table_1; +---- +NULL +NULL +NULL +NULL + +query ? +SELECT column1[1000] FROM map_array_table_1; +---- +NULL +NULL +NULL +NULL + + +query error DataFusion error: Arrow error: Invalid argument error +SELECT column1[NULL] FROM map_array_table_1; + query ??? select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_1; ---- @@ -722,3 +759,28 @@ drop table map_array_table_1; statement ok drop table map_array_table_2; + + +statement ok +create table tt as values(MAP{[1,2,3]:1}, MAP {{'a':1, 'b':2}:2}, MAP{true: 3}); + +# accessing using an array +query I +select column1[make_array(1, 2, 3)] from tt; +---- +1 + +# accessing using a struct +query I +select column2[{a:1, b: 2}] from tt; +---- +2 + +# accessing using Bool +query I +select column3[true] from tt; +---- +3 + +statement ok +drop table tt; --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org