This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 0f2a68ee16 Validate ScalarUDF output rows and fix nulls for `array_has` and `get_field` for `Map` (#10148)
0f2a68ee16 is described below

commit 0f2a68ee1676c0d141d2c7cacf4b7c21d0033870
Author: Duong Cong Toai <[email protected]>
AuthorDate: Mon Apr 29 21:27:40 2024 +0200

    Validate ScalarUDF output rows and fix nulls for `array_has` and `get_field` for `Map` (#10148)
    
    * validate input/output of udf
    
    * clip
    
    * fmt
    
    * clean garbage
    
    * don't check if output is scalar
    
    * lint
    
    * fix array_has
    
    * rm debug
    
    * chore: temp code for demonstration
    
    * getfield retains number of rows
    
    * rust fmt
    
    * minor comments
    
    * fmt
    
    * refactor
    
    * compile err
    
    * fmt again
    
    * fmt
    
    * add validate_number_of_rows for UDF
    
    * only check for columnarvalue::array
---
 .../user_defined/user_defined_scalar_functions.rs  | 40 ++++++++++++-
 datafusion/functions-array/src/array_has.rs        | 60 ++++++++++---------
 datafusion/functions/src/core/getfield.rs          | 70 +++++++++++++++-------
 datafusion/physical-expr/src/scalar_function.rs    | 15 +++--
 datafusion/sqllogictest/test_files/array.slt       | 15 +++--
 datafusion/sqllogictest/test_files/map.slt         |  1 +
 6 files changed, 141 insertions(+), 60 deletions(-)

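The gist of the change, seen from the caller's side: a scalar UDF whose output
array does not have one row per input row is now rejected with an internal
error instead of flowing further into the plan. A minimal sketch (not part of
this commit; it mirrors the new regression test in
user_defined_scalar_functions.rs and assumes a tokio main for a standalone
program):

    use std::sync::Arc;

    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::error::Result;
    use datafusion::logical_expr::{create_udf, ColumnarValue, Volatility};
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> Result<()> {
        // Two input rows...
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(Int32Array::from(vec![1, 2]))],
        )?;
        let ctx = SessionContext::new();
        ctx.register_batch("t", batch)?;

        // ...but the UDF always returns a single row, whatever the batch size.
        let buggy_udf = Arc::new(|_: &[ColumnarValue]| {
            Ok(ColumnarValue::Array(Arc::new(Int32Array::from(vec![0]))))
        });
        ctx.register_udf(create_udf(
            "buggy_func",
            vec![DataType::Int32],
            Arc::new(DataType::Int32),
            Volatility::Immutable,
            buggy_udf,
        ));

        // With this commit the query fails with
        // "UDF returned a different number of rows than expected".
        let result = ctx.sql("select buggy_func(a) from t").await?.show().await;
        assert!(result.is_err());
        Ok(())
    }
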
diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 4f262b54fb..def9fcb4c6 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -26,7 +26,7 @@ use datafusion_common::{
     assert_batches_eq, assert_batches_sorted_eq, cast::as_float64_array,
     cast::as_int32_array, not_impl_err, plan_err, ExprSchema, Result, ScalarValue,
 };
-use datafusion_common::{exec_err, internal_err, DataFusionError};
+use datafusion_common::{assert_contains, exec_err, internal_err, DataFusionError};
 use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
 use datafusion_expr::{
@@ -205,6 +205,44 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
     }
 }
 
+#[tokio::test]
+async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> {
+    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+
+    let batch = RecordBatch::try_new(
+        Arc::new(schema.clone()),
+        vec![Arc::new(Int32Array::from(vec![1, 2]))],
+    )?;
+
+    let ctx = SessionContext::new();
+
+    ctx.register_batch("t", batch)?;
+
+    // udf that always returns 1 row
+    let buggy_udf = Arc::new(|_: &[ColumnarValue]| {
+        Ok(ColumnarValue::Array(Arc::new(Int32Array::from(vec![0]))))
+    });
+
+    ctx.register_udf(create_udf(
+        "buggy_func",
+        vec![DataType::Int32],
+        Arc::new(DataType::Int32),
+        Volatility::Immutable,
+        buggy_udf,
+    ));
+    assert_contains!(
+        ctx.sql("select buggy_func(a) from t")
+            .await?
+            .show()
+            .await
+            .err()
+            .unwrap()
+            .to_string(),
+        "UDF returned a different number of rows than expected"
+    );
+    Ok(())
+}
+
 #[tokio::test]
 async fn scalar_udf_zero_params() -> Result<()> {
     let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
diff --git a/datafusion/functions-array/src/array_has.rs 
b/datafusion/functions-array/src/array_has.rs
index ee064335c1..e5e8add95f 100644
--- a/datafusion/functions-array/src/array_has.rs
+++ b/datafusion/functions-array/src/array_has.rs
@@ -288,36 +288,40 @@ fn general_array_has_dispatch<O: OffsetSizeTrait>(
     } else {
         array
     };
-
     for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() {
-        if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
-            let arr_values = converter.convert_columns(&[arr])?;
-            let sub_arr_values = if comparison_type != ComparisonType::Single {
-                converter.convert_columns(&[sub_arr])?
-            } else {
-                converter.convert_columns(&[element.clone()])?
-            };
-
-            let mut res = match comparison_type {
-                ComparisonType::All => sub_arr_values
-                    .iter()
-                    .dedup()
-                    .all(|elem| arr_values.iter().dedup().any(|x| x == elem)),
-                ComparisonType::Any => sub_arr_values
-                    .iter()
-                    .dedup()
-                    .any(|elem| arr_values.iter().dedup().any(|x| x == elem)),
-                ComparisonType::Single => arr_values
-                    .iter()
-                    .dedup()
-                    .any(|x| x == sub_arr_values.row(row_idx)),
-            };
-
-            if comparison_type == ComparisonType::Any {
-                res |= res;
+        match (arr, sub_arr) {
+            (Some(arr), Some(sub_arr)) => {
+                let arr_values = converter.convert_columns(&[arr])?;
+                let sub_arr_values = if comparison_type != ComparisonType::Single {
+                    converter.convert_columns(&[sub_arr])?
+                } else {
+                    converter.convert_columns(&[element.clone()])?
+                };
+
+                let mut res = match comparison_type {
+                    ComparisonType::All => sub_arr_values
+                        .iter()
+                        .dedup()
+                        .all(|elem| arr_values.iter().dedup().any(|x| x == elem)),
+                    ComparisonType::Any => sub_arr_values
+                        .iter()
+                        .dedup()
+                        .any(|elem| arr_values.iter().dedup().any(|x| x == elem)),
+                    ComparisonType::Single => arr_values
+                        .iter()
+                        .dedup()
+                        .any(|x| x == sub_arr_values.row(row_idx)),
+                };
+
+                if comparison_type == ComparisonType::Any {
+                    res |= res;
+                }
+                boolean_builder.append_value(res);
+            }
+            // respect null input
+            (_, _) => {
+                boolean_builder.append_null();
             }
-
-            boolean_builder.append_value(res);
         }
     }
     Ok(Arc::new(boolean_builder.finish()))
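
The rewrite above replaces the old `if let (Some(arr), Some(sub_arr))` guard,
which silently skipped null rows (so the output had fewer rows than the
input), with a `match` that appends a null instead. A simplified sketch of
that pattern over plain Rust slices; the `has_single` helper and its types are
illustrative, not taken from the commit:

    use arrow::array::{BooleanArray, BooleanBuilder};

    /// One output row per input row: a value when both sides are present,
    /// a null otherwise (rather than skipping the row entirely).
    fn has_single(haystacks: &[Option<Vec<i64>>], needles: &[Option<i64>]) -> BooleanArray {
        let mut builder = BooleanBuilder::with_capacity(haystacks.len());
        for (haystack, needle) in haystacks.iter().zip(needles.iter()) {
            match (haystack, needle) {
                (Some(haystack), Some(needle)) => builder.append_value(haystack.contains(needle)),
                // respect null input
                _ => builder.append_null(),
            }
        }
        builder.finish()
    }

    // has_single(&[Some(vec![1, 2]), None], &[Some(2), Some(1)]) -> [true, null]
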
diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs
index b00b8ea553..a092aac159 100644
--- a/datafusion/functions/src/core/getfield.rs
+++ b/datafusion/functions/src/core/getfield.rs
@@ -15,7 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::array::{Scalar, StringArray};
+use arrow::array::{
+    make_array, Array, Capacities, MutableArrayData, Scalar, StringArray,
+};
 use arrow::datatypes::DataType;
 use datafusion_common::cast::{as_map_array, as_struct_array};
 use datafusion_common::{exec_err, ExprSchema, Result, ScalarValue};
@@ -107,29 +109,55 @@ impl ScalarUDFImpl for GetFieldFunc {
                 );
             }
         };
+
         match (array.data_type(), name) {
-                (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => {
-                    let map_array = as_map_array(array.as_ref())?;
-                    let key_scalar = Scalar::new(StringArray::from(vec![k.clone()]));
-                    let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?;
-                    let entries = arrow::compute::filter(map_array.entries(), &keys)?;
-                    let entries_struct_array = as_struct_array(entries.as_ref())?;
-                    Ok(ColumnarValue::Array(entries_struct_array.column(1).clone()))
-                }
-                (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
-                    let as_struct_array = as_struct_array(&array)?;
-                    match as_struct_array.column_by_name(k) {
-                        None => exec_err!(
-                            "get indexed field {k} not found in struct"),
-                        Some(col) => Ok(ColumnarValue::Array(col.clone()))
+            (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => {
+                let map_array = as_map_array(array.as_ref())?;
+                let key_scalar: Scalar<arrow::array::GenericByteArray<arrow::datatypes::GenericStringType<i32>>> = Scalar::new(StringArray::from(vec![k.clone()]));
+                let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?;
+
+                // note that this array has more entries than the expected output/input size
+                // because the MapArray is flattened
+                let original_data =  map_array.entries().column(1).to_data();
+                let capacity = Capacities::Array(original_data.len());
+                let mut mutable =
+                    MutableArrayData::with_capacities(vec![&original_data], true,
+                         capacity);
+
+                for entry in 0..map_array.len(){
+                    let start = map_array.value_offsets()[entry] as usize;
+                    let end = map_array.value_offsets()[entry + 1] as usize;
+
+                    let maybe_matched =
+                                        keys.slice(start, end-start).
+                                        iter().enumerate().
+                                        find(|(_, t)| t.unwrap());
+                    if maybe_matched.is_none(){
+                        mutable.extend_nulls(1);
+                        continue
                     }
+                    let (match_offset,_) = maybe_matched.unwrap();
+                    mutable.extend(0, start + match_offset, start + match_offset + 1);
+                }
+                let data = mutable.freeze();
+                let data = make_array(data);
+                Ok(ColumnarValue::Array(data))
+            }
+            (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
+                let as_struct_array = as_struct_array(&array)?;
+                match as_struct_array.column_by_name(k) {
+                    None => exec_err!("get indexed field {k} not found in struct"),
+                    Some(col) => Ok(ColumnarValue::Array(col.clone())),
                 }
-                (DataType::Struct(_), name) => exec_err!(
-                    "get indexed field is only possible on struct with utf8 
indexes. \
-                             Tried with {name:?} index"),
-                (dt, name) => exec_err!(
-                                "get indexed field is only possible on lists 
with int64 indexes or struct \
-                                         with utf8 indexes. Tried {dt:?} with 
{name:?} index"),
             }
+            (DataType::Struct(_), name) => exec_err!(
+                "get indexed field is only possible on struct with utf8 
indexes. \
+                             Tried with {name:?} index"
+            ),
+            (dt, name) => exec_err!(
+                "get indexed field is only possible on lists with int64 
indexes or struct \
+                                         with utf8 indexes. Tried {dt:?} with 
{name:?} index"
+            ),
+        }
     }
 }
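
The new Map branch above walks value_offsets() once per input row and either
copies the single matching value or appends a null, instead of filtering the
flattened entries (which dropped rows whose key was missing). A stand-alone
sketch of that MutableArrayData copy-or-null pattern, using made-up values:

    use arrow::array::{make_array, Array, Capacities, Int32Array, MutableArrayData};

    fn main() {
        // Pretend these are the flattened map values; output row 0 should keep
        // the value 20 and output row 1 should be null.
        let values = Int32Array::from(vec![10, 20, 30]);
        let values_data = values.to_data();

        let mut mutable = MutableArrayData::with_capacities(
            vec![&values_data],
            true, // the output may contain nulls
            Capacities::Array(values_data.len()),
        );

        // Row 0: a key matched at offset 1, so copy exactly that slot.
        mutable.extend(0, 1, 2);
        // Row 1: no key matched, so emit a null and keep the row count intact.
        mutable.extend_nulls(1);

        let result = make_array(mutable.freeze());
        assert_eq!(result.len(), 2);
        assert!(result.is_null(1));
    }
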
diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs
index 3b360fc20c..b9c6ff3cfe 100644
--- a/datafusion/physical-expr/src/scalar_function.rs
+++ b/datafusion/physical-expr/src/scalar_function.rs
@@ -146,11 +146,18 @@ impl PhysicalExpr for ScalarFunctionExpr {
         // evaluate the function
         match self.fun {
             ScalarFunctionDefinition::UDF(ref fun) => {
-                if self.args.is_empty() {
-                    fun.invoke_no_args(batch.num_rows())
-                } else {
-                    fun.invoke(&inputs)
+                let output = match self.args.is_empty() {
+                    true => fun.invoke_no_args(batch.num_rows()),
+                    false => fun.invoke(&inputs),
+                }?;
+
+                if let ColumnarValue::Array(array) = &output {
+                    if array.len() != batch.num_rows() {
+                        return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}",
+                        batch.num_rows(), array.len());
+                    }
                 }
+                Ok(output)
             }
             ScalarFunctionDefinition::Name(_) => {
                 internal_err!(
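
Note the check above only fires for ColumnarValue::Array outputs; scalar
results are left alone, since a scalar logically applies to every row of the
batch. For contrast with the buggy UDF earlier, here is a sketch of a UDF that
satisfies the check by returning exactly one value per input row (the
`double_i32` name and body are illustrative, not from this commit, and assume
the current API where ScalarValue::to_array returns a Result):

    use std::sync::Arc;

    use arrow::array::Int32Array;
    use arrow::datatypes::DataType;
    use datafusion::common::cast::as_int32_array;
    use datafusion::error::Result;
    use datafusion::logical_expr::{create_udf, ColumnarValue, ScalarUDF, Volatility};

    fn double_i32() -> ScalarUDF {
        let fun = Arc::new(|args: &[ColumnarValue]| -> Result<ColumnarValue> {
            // Normalize the argument to an array...
            let array = match &args[0] {
                ColumnarValue::Array(array) => array.clone(),
                ColumnarValue::Scalar(scalar) => scalar.to_array()?,
            };
            // ...and emit one output value per input row, preserving nulls, so
            // the new length check in ScalarFunctionExpr::evaluate passes.
            let input = as_int32_array(&array)?;
            let doubled: Int32Array = input.iter().map(|v| v.map(|x| x * 2)).collect();
            Ok(ColumnarValue::Array(Arc::new(doubled)))
        });
        create_udf(
            "double_i32",
            vec![DataType::Int32],
            Arc::new(DataType::Int32),
            Volatility::Immutable,
            fun,
        )
    }
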
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index c3c5603daf..b33419ecd4 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -5169,8 +5169,9 @@ false false false true
 true false true false
 true false false true
 false true false false
-false false false false
-false false false false
+NULL NULL false false
+false false NULL false
+false false false NULL
 
 query BBBB
 select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)),
@@ -5183,8 +5184,9 @@ false false false true
 true false true false
 true false false true
 false true false false
-false false false false
-false false false false
+NULL NULL false false
+false false NULL false
+false false false NULL
 
 query BBBB
 select array_has(column1, make_array(5, 6)),
@@ -5197,8 +5199,9 @@ false false false true
 true false true false
 true false false true
 false true false false
-false false false false
-false false false false
+NULL NULL false false
+false false NULL false
+false false false NULL
 
 query BBBBBBBBBBBBB
 select array_has_all(make_array(1,2,3), make_array(1,3)),
diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt
index 415fabf224..8ff7d119c4 100644
--- a/datafusion/sqllogictest/test_files/map.slt
+++ b/datafusion/sqllogictest/test_files/map.slt
@@ -44,6 +44,7 @@ DELETE 24
 query T
 SELECT strings['not_found'] FROM data LIMIT 1;
 ----
+NULL
 
 statement ok
 drop table data;
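
With the fix above, looking up a key that is not present in the map yields
NULL for that row instead of producing no row at all.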


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
