This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 310eab006b Map access supports constant-resolvable expressions (#14712)
310eab006b is described below

commit 310eab006b26885ec558cf5c4572c73a7d824ee9
Author: Lordworms <48054792+lordwo...@users.noreply.github.com>
AuthorDate: Thu Feb 20 04:00:24 2025 -0800

    Map access supports constant-resolvable expressions (#14712)
    
    * Map access supports constant-resolvable expressions
    
    * adding tests
    
    fix clippy
    
    fix clippy
    
    fix clippy
    
    * fix clippy
---
 datafusion/functions-nested/src/planner.rs |  22 ++++--
 datafusion/functions/src/core/getfield.rs  | 111 ++++++++++++++++++-----------
 datafusion/sqllogictest/test_files/map.slt |  62 ++++++++++++++++
 3 files changed, 148 insertions(+), 47 deletions(-)

diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs
index d55176a42c..369eaecb19 100644
--- a/datafusion/functions-nested/src/planner.rs
+++ b/datafusion/functions-nested/src/planner.rs
@@ -17,17 +17,20 @@
 
 //! SQL planning extensions like [`NestedFunctionPlanner`] and 
[`FieldAccessPlanner`]
 
-use std::sync::Arc;
-
+use arrow::datatypes::DataType;
+use datafusion_common::ExprSchema;
 use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result};
-use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, 
ScalarFunction};
+use datafusion_expr::expr::ScalarFunction;
+use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
 use datafusion_expr::AggregateUDF;
 use datafusion_expr::{
     planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
     sqlparser, Expr, ExprSchemable, GetFieldAccess,
 };
+use datafusion_functions::core::get_field as get_field_inner;
 use datafusion_functions::expr_fn::get_field;
 use datafusion_functions_aggregate::nth_value::nth_value_udaf;
+use std::sync::Arc;
 
 use crate::map::map_udf;
 use crate::{
@@ -140,7 +143,7 @@ impl ExprPlanner for FieldAccessPlanner {
     fn plan_field_access(
         &self,
         expr: RawFieldAccessExpr,
-        _schema: &DFSchema,
+        schema: &DFSchema,
     ) -> Result<PlannerResult<RawFieldAccessExpr>> {
         let RawFieldAccessExpr { expr, field_access } = expr;
 
@@ -173,6 +176,17 @@ impl ExprPlanner for FieldAccessPlanner {
                             null_treatment,
                         )),
                     )),
+                    // special case: index access on a map-typed column is planned as `get_field`
+                    Expr::Column(ref c)
+                        if matches!(schema.data_type(c)?, DataType::Map(_, _)) 
=>
+                    {
+                        Ok(PlannerResult::Planned(Expr::ScalarFunction(
+                            ScalarFunction::new_udf(
+                                get_field_inner(),
+                                vec![expr, *index],
+                            ),
+                        )))
+                    }
                     _ => Ok(PlannerResult::Planned(array_element(expr, 
*index))),
                 }
             }
diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs
index d667d0d8c1..d900ee5825 100644
--- a/datafusion/functions/src/core/getfield.rs
+++ b/datafusion/functions/src/core/getfield.rs
@@ -16,9 +16,12 @@
 // under the License.
 
 use arrow::array::{
-    make_array, Array, Capacities, MutableArrayData, Scalar, StringArray,
+    make_array, make_comparator, Array, BooleanArray, Capacities, 
MutableArrayData,
+    Scalar,
 };
+use arrow::compute::SortOptions;
 use arrow::datatypes::DataType;
+use arrow_buffer::NullBuffer;
 use datafusion_common::cast::{as_map_array, as_struct_array};
 use datafusion_common::{
     exec_err, internal_err, plan_datafusion_err, utils::take_function_args, 
Result,
@@ -106,11 +109,7 @@ impl ScalarUDFImpl for GetFieldFunc {
 
         let name = match field_name {
             Expr::Literal(name) => name,
-            _ => {
-                return exec_err!(
-                    "get_field function requires the argument field_name to be 
a string"
-                );
-            }
+            other => &ScalarValue::Utf8(Some(other.schema_name().to_string())),
         };
 
         Ok(format!("{base}[{name}]"))
@@ -118,14 +117,9 @@ impl ScalarUDFImpl for GetFieldFunc {
 
     fn schema_name(&self, args: &[Expr]) -> Result<String> {
         let [base, field_name] = take_function_args(self.name(), args)?;
-
         let name = match field_name {
             Expr::Literal(name) => name,
-            _ => {
-                return exec_err!(
-                    "get_field function requires the argument field_name to be 
a string"
-                );
-            }
+            other => &ScalarValue::Utf8(Some(other.schema_name().to_string())),
         };
 
         Ok(format!("{}[{}]", base.schema_name(), name))
@@ -182,7 +176,6 @@ impl ScalarUDFImpl for GetFieldFunc {
         let arrays =
             ColumnarValue::values_to_arrays(&[base.clone(), 
field_name.clone()])?;
         let array = Arc::clone(&arrays[0]);
-
         let name = match field_name {
             ColumnarValue::Scalar(name) => name,
             _ => {
@@ -192,38 +185,70 @@ impl ScalarUDFImpl for GetFieldFunc {
             }
         };
 
+        fn process_map_array(
+            array: Arc<dyn Array>,
+            key_array: Arc<dyn Array>,
+        ) -> Result<ColumnarValue> {
+            let map_array = as_map_array(array.as_ref())?;
+            let keys = if key_array.data_type().is_nested() {
+                let comparator = make_comparator(
+                    map_array.keys().as_ref(),
+                    key_array.as_ref(),
+                    SortOptions::default(),
+                )?;
+                let len = map_array.keys().len().min(key_array.len());
+                let values = (0..len).map(|i| comparator(i, 
i).is_eq()).collect();
+                let nulls =
+                    NullBuffer::union(map_array.keys().nulls(), 
key_array.nulls());
+                BooleanArray::new(values, nulls)
+            } else {
+                let be_compared = Scalar::new(key_array);
+                arrow::compute::kernels::cmp::eq(&be_compared, 
map_array.keys())?
+            };
+
+            let original_data = map_array.entries().column(1).to_data();
+            let capacity = Capacities::Array(original_data.len());
+            let mut mutable =
+                MutableArrayData::with_capacities(vec![&original_data], true, 
capacity);
+
+            for entry in 0..map_array.len() {
+                let start = map_array.value_offsets()[entry] as usize;
+                let end = map_array.value_offsets()[entry + 1] as usize;
+
+                let maybe_matched = keys
+                    .slice(start, end - start)
+                    .iter()
+                    .enumerate()
+                    .find(|(_, t)| t.unwrap());
+
+                if maybe_matched.is_none() {
+                    mutable.extend_nulls(1);
+                    continue;
+                }
+                let (match_offset, _) = maybe_matched.unwrap();
+                mutable.extend(0, start + match_offset, start + match_offset + 
1);
+            }
+
+            let data = mutable.freeze();
+            let data = make_array(data);
+            Ok(ColumnarValue::Array(data))
+        }
+
         match (array.data_type(), name) {
-            (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => {
-                let map_array = as_map_array(array.as_ref())?;
-                let key_scalar: Scalar<arrow::array::GenericByteArray<arrow::datatypes::GenericStringType<i32>>> = Scalar::new(StringArray::from(vec![k.clone()]));
-                let keys = arrow::compute::kernels::cmp::eq(&key_scalar, 
map_array.keys())?;
-
-                // note that this array has more entries than the expected 
output/input size
-                // because map_array is flattened
-                let original_data =  map_array.entries().column(1).to_data();
-                let capacity = Capacities::Array(original_data.len());
-                let mut mutable =
-                    MutableArrayData::with_capacities(vec![&original_data], 
true,
-                         capacity);
-
-                for entry in 0..map_array.len(){
-                    let start = map_array.value_offsets()[entry] as usize;
-                    let end = map_array.value_offsets()[entry + 1] as usize;
-
-                    let maybe_matched =
-                                        keys.slice(start, end-start).
-                                        iter().enumerate().
-                                        find(|(_, t)| t.unwrap());
-                    if maybe_matched.is_none() {
-                        mutable.extend_nulls(1);
-                        continue
-                    }
-                    let (match_offset,_) = maybe_matched.unwrap();
-                    mutable.extend(0, start + match_offset, start + 
match_offset + 1);
+            (DataType::Map(_, _), ScalarValue::List(arr)) => {
+                let key_array: Arc<dyn Array> = Arc::new((**arr).clone());
+                process_map_array(array, key_array)
+            }
+            (DataType::Map(_, _), ScalarValue::Struct(arr)) => {
+                process_map_array(array, Arc::clone(arr) as Arc<dyn Array>)
+            }
+            (DataType::Map(_, _), other) => {
+                let data_type = other.data_type();
+                if data_type.is_nested() {
+                    exec_err!("unsupported type {:?} for map access", 
data_type)
+                } else {
+                    process_map_array(array, other.to_array()?)
                 }
-                let data = mutable.freeze();
-                let data = make_array(data);
-                Ok(ColumnarValue::Array(data))
             }
             (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
                 let as_struct_array = as_struct_array(&array)?;
diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt
index 71296b6f64..42a4ba6218 100644
--- a/datafusion/sqllogictest/test_files/map.slt
+++ b/datafusion/sqllogictest/test_files/map.slt
@@ -592,6 +592,43 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7)
 [NULL] [NULL] [[1, NULL, 3]]
 [NULL] [NULL] [NULL]
 
+query ?
+select column1[1] from map_array_table_1;
+----
+[1, NULL, 3]
+NULL
+NULL
+NULL
+
+query ?
+select column1[-1000 + 1001] from map_array_table_1;
+----
+[1, NULL, 3]
+NULL
+NULL
+NULL
+
+# test for negative scenario
+query ?
+SELECT column1[-1] FROM map_array_table_1;
+----
+NULL
+NULL
+NULL
+NULL
+
+query ?
+SELECT column1[1000] FROM map_array_table_1;
+----
+NULL
+NULL
+NULL
+NULL
+
+
+query error DataFusion error: Arrow error: Invalid argument error
+SELECT column1[NULL] FROM map_array_table_1;
+
 query ???
 select map_extract(column1, column2), map_extract(column1, column3), 
map_extract(column1, column4) from map_array_table_1;
 ----
@@ -722,3 +759,28 @@ drop table map_array_table_1;
 
 statement ok
 drop table map_array_table_2;
+
+
+statement ok
+create table tt as values(MAP{[1,2,3]:1}, MAP {{'a':1, 'b':2}:2}, MAP{true: 
3});
+
+# accessing using an array
+query I
+select column1[make_array(1, 2, 3)] from tt;
+----
+1
+
+# accessing using a struct
+query I
+select column2[{a:1, b: 2}] from tt;
+----
+2
+
+# accessing using Bool
+query I
+select column3[true] from tt;
+----
+3
+
+statement ok
+drop table tt;


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to