This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 0f2a68ee16 Validate ScalarUDF output rows and fix nulls for `array_has` and `get_field` for `Map` (#10148)
0f2a68ee16 is described below
commit 0f2a68ee1676c0d141d2c7cacf4b7c21d0033870
Author: Duong Cong Toai <[email protected]>
AuthorDate: Mon Apr 29 21:27:40 2024 +0200
Validate ScalarUDF output rows and fix nulls for `array_has` and `get_field` for `Map` (#10148)
* validate input/output of UDF
* clippy
* fmt
* clean up leftover code
* don't check if output is scalar
* lint
* fix array_has
* remove debug code
* chore: temp code for demonstration
* `get_field` retains the number of rows
* rustfmt
* minor comments
* fmt
* refactor
* fix compile error
* fmt again
* fmt
* add validate_number_of_rows for UDF
* only check for `ColumnarValue::Array`
---
.../user_defined/user_defined_scalar_functions.rs | 40 ++++++++++++-
datafusion/functions-array/src/array_has.rs | 60 ++++++++++---------
datafusion/functions/src/core/getfield.rs | 70 +++++++++++++++-------
datafusion/physical-expr/src/scalar_function.rs | 15 +++--
datafusion/sqllogictest/test_files/array.slt | 15 +++--
datafusion/sqllogictest/test_files/map.slt | 1 +
6 files changed, 141 insertions(+), 60 deletions(-)
diff --git
a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 4f262b54fb..def9fcb4c6 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -26,7 +26,7 @@ use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, cast::as_float64_array,
cast::as_int32_array, not_impl_err, plan_err, ExprSchema, Result,
ScalarValue,
};
-use datafusion_common::{exec_err, internal_err, DataFusionError};
+use datafusion_common::{assert_contains, exec_err, internal_err,
DataFusionError};
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
@@ -205,6 +205,44 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
}
}
+#[tokio::test]
+async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+ // Checks that a UDF whose output array length != input row count is rejected.
+ let batch = RecordBatch::try_new(
+ Arc::new(schema.clone()),
+ vec![Arc::new(Int32Array::from(vec![1, 2]))],
+ )?;
+ // The batch has 2 rows, so a UDF's array result must also have length 2.
+ let ctx = SessionContext::new();
+ // Register the batch as table "t" for the SQL query below.
+ ctx.register_batch("t", batch)?;
+
+ // Deliberately buggy UDF: always returns a 1-row array, ignoring its input.
+ let buggy_udf = Arc::new(|_: &[ColumnarValue]| {
+ Ok(ColumnarValue::Array(Arc::new(Int32Array::from(vec![0]))))
+ });
+ // Register it, then expect the query to fail with the row-count mismatch error.
+ ctx.register_udf(create_udf(
+ "buggy_func",
+ vec![DataType::Int32],
+ Arc::new(DataType::Int32),
+ Volatility::Immutable,
+ buggy_udf,
+ ));
+ assert_contains!(
+ ctx.sql("select buggy_func(a) from t")
+ .await?
+ .show()
+ .await
+ .err()
+ .unwrap()
+ .to_string(),
+ "UDF returned a different number of rows than expected"
+ );
+ Ok(())
+}
+
#[tokio::test]
async fn scalar_udf_zero_params() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
diff --git a/datafusion/functions-array/src/array_has.rs
b/datafusion/functions-array/src/array_has.rs
index ee064335c1..e5e8add95f 100644
--- a/datafusion/functions-array/src/array_has.rs
+++ b/datafusion/functions-array/src/array_has.rs
@@ -288,36 +288,40 @@ fn general_array_has_dispatch<O: OffsetSizeTrait>(
} else {
array
};
-
for (row_idx, (arr, sub_arr)) in
array.iter().zip(sub_array.iter()).enumerate() {
- if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
- let arr_values = converter.convert_columns(&[arr])?;
- let sub_arr_values = if comparison_type != ComparisonType::Single {
- converter.convert_columns(&[sub_arr])?
- } else {
- converter.convert_columns(&[element.clone()])?
- };
-
- let mut res = match comparison_type {
- ComparisonType::All => sub_arr_values
- .iter()
- .dedup()
- .all(|elem| arr_values.iter().dedup().any(|x| x == elem)),
- ComparisonType::Any => sub_arr_values
- .iter()
- .dedup()
- .any(|elem| arr_values.iter().dedup().any(|x| x == elem)),
- ComparisonType::Single => arr_values
- .iter()
- .dedup()
- .any(|x| x == sub_arr_values.row(row_idx)),
- };
-
- if comparison_type == ComparisonType::Any {
- res |= res;
+ match (arr, sub_arr) {
+ (Some(arr), Some(sub_arr)) => {
+ let arr_values = converter.convert_columns(&[arr])?;
+ let sub_arr_values = if comparison_type !=
ComparisonType::Single {
+ converter.convert_columns(&[sub_arr])?
+ } else {
+ converter.convert_columns(&[element.clone()])?
+ };
+
+ let mut res = match comparison_type {
+ ComparisonType::All => sub_arr_values
+ .iter()
+ .dedup()
+ .all(|elem| arr_values.iter().dedup().any(|x| x ==
elem)),
+ ComparisonType::Any => sub_arr_values
+ .iter()
+ .dedup()
+ .any(|elem| arr_values.iter().dedup().any(|x| x ==
elem)),
+ ComparisonType::Single => arr_values
+ .iter()
+ .dedup()
+ .any(|x| x == sub_arr_values.row(row_idx)),
+ };
+
+ if comparison_type == ComparisonType::Any {
+ res |= res;
+ }
+ boolean_builder.append_value(res);
+ }
+ // respect null input
+ (_, _) => {
+ boolean_builder.append_null();
}
-
- boolean_builder.append_value(res);
}
}
Ok(Arc::new(boolean_builder.finish()))
diff --git a/datafusion/functions/src/core/getfield.rs
b/datafusion/functions/src/core/getfield.rs
index b00b8ea553..a092aac159 100644
--- a/datafusion/functions/src/core/getfield.rs
+++ b/datafusion/functions/src/core/getfield.rs
@@ -15,7 +15,9 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::array::{Scalar, StringArray};
+use arrow::array::{
+ make_array, Array, Capacities, MutableArrayData, Scalar, StringArray,
+};
use arrow::datatypes::DataType;
use datafusion_common::cast::{as_map_array, as_struct_array};
use datafusion_common::{exec_err, ExprSchema, Result, ScalarValue};
@@ -107,29 +109,55 @@ impl ScalarUDFImpl for GetFieldFunc {
);
}
};
+
match (array.data_type(), name) {
- (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => {
- let map_array = as_map_array(array.as_ref())?;
- let key_scalar =
Scalar::new(StringArray::from(vec![k.clone()]));
- let keys = arrow::compute::kernels::cmp::eq(&key_scalar,
map_array.keys())?;
- let entries = arrow::compute::filter(map_array.entries(),
&keys)?;
- let entries_struct_array =
as_struct_array(entries.as_ref())?;
-
Ok(ColumnarValue::Array(entries_struct_array.column(1).clone()))
- }
- (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
- let as_struct_array = as_struct_array(&array)?;
- match as_struct_array.column_by_name(k) {
- None => exec_err!(
- "get indexed field {k} not found in struct"),
- Some(col) => Ok(ColumnarValue::Array(col.clone()))
+ (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => {
+ let map_array = as_map_array(array.as_ref())?;
+ let key_scalar:
Scalar<arrow::array::GenericByteArray<arrow::datatypes::GenericStringType<i32>>>
= Scalar::new(StringArray::from(vec![k.clone()]));
+ let keys = arrow::compute::kernels::cmp::eq(&key_scalar,
map_array.keys())?;
+
+ // note that this array has more entries than the expected
output/input size
+ // because maparray is flatten
+ let original_data = map_array.entries().column(1).to_data();
+ let capacity = Capacities::Array(original_data.len());
+ let mut mutable =
+ MutableArrayData::with_capacities(vec![&original_data],
true,
+ capacity);
+
+ for entry in 0..map_array.len(){
+ let start = map_array.value_offsets()[entry] as usize;
+ let end = map_array.value_offsets()[entry + 1] as usize;
+
+ let maybe_matched =
+ keys.slice(start, end-start).
+ iter().enumerate().
+ find(|(_, t)| t.unwrap());
+ if maybe_matched.is_none(){
+ mutable.extend_nulls(1);
+ continue
}
+ let (match_offset,_) = maybe_matched.unwrap();
+ mutable.extend(0, start + match_offset, start +
match_offset + 1);
+ }
+ let data = mutable.freeze();
+ let data = make_array(data);
+ Ok(ColumnarValue::Array(data))
+ }
+ (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
+ let as_struct_array = as_struct_array(&array)?;
+ match as_struct_array.column_by_name(k) {
+ None => exec_err!("get indexed field {k} not found in
struct"),
+ Some(col) => Ok(ColumnarValue::Array(col.clone())),
}
- (DataType::Struct(_), name) => exec_err!(
- "get indexed field is only possible on struct with utf8
indexes. \
- Tried with {name:?} index"),
- (dt, name) => exec_err!(
- "get indexed field is only possible on lists
with int64 indexes or struct \
- with utf8 indexes. Tried {dt:?} with
{name:?} index"),
}
+ (DataType::Struct(_), name) => exec_err!(
+ "get indexed field is only possible on struct with utf8
indexes. \
+ Tried with {name:?} index"
+ ),
+ (dt, name) => exec_err!(
+ "get indexed field is only possible on lists with int64
indexes or struct \
+ with utf8 indexes. Tried {dt:?} with
{name:?} index"
+ ),
+ }
}
}
diff --git a/datafusion/physical-expr/src/scalar_function.rs
b/datafusion/physical-expr/src/scalar_function.rs
index 3b360fc20c..b9c6ff3cfe 100644
--- a/datafusion/physical-expr/src/scalar_function.rs
+++ b/datafusion/physical-expr/src/scalar_function.rs
@@ -146,11 +146,18 @@ impl PhysicalExpr for ScalarFunctionExpr {
// evaluate the function
match self.fun {
ScalarFunctionDefinition::UDF(ref fun) => {
- if self.args.is_empty() {
- fun.invoke_no_args(batch.num_rows())
- } else {
- fun.invoke(&inputs)
+ let output = match self.args.is_empty() {
+ true => fun.invoke_no_args(batch.num_rows()),
+ false => fun.invoke(&inputs),
+ }?;
+
+ if let ColumnarValue::Array(array) = &output {
+ if array.len() != batch.num_rows() {
+ return internal_err!("UDF returned a different number
of rows than expected. Expected: {}, Got: {}",
+ batch.num_rows(), array.len());
+ }
}
+ Ok(output)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!(
diff --git a/datafusion/sqllogictest/test_files/array.slt
b/datafusion/sqllogictest/test_files/array.slt
index c3c5603daf..b33419ecd4 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -5169,8 +5169,9 @@ false false false true
true false true false
true false false true
false true false false
-false false false false
-false false false false
+NULL NULL false false
+false false NULL false
+false false false NULL
query BBBB
select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5,
6)),
@@ -5183,8 +5184,9 @@ false false false true
true false true false
true false false true
false true false false
-false false false false
-false false false false
+NULL NULL false false
+false false NULL false
+false false false NULL
query BBBB
select array_has(column1, make_array(5, 6)),
@@ -5197,8 +5199,9 @@ false false false true
true false true false
true false false true
false true false false
-false false false false
-false false false false
+NULL NULL false false
+false false NULL false
+false false false NULL
query BBBBBBBBBBBBB
select array_has_all(make_array(1,2,3), make_array(1,3)),
diff --git a/datafusion/sqllogictest/test_files/map.slt
b/datafusion/sqllogictest/test_files/map.slt
index 415fabf224..8ff7d119c4 100644
--- a/datafusion/sqllogictest/test_files/map.slt
+++ b/datafusion/sqllogictest/test_files/map.slt
@@ -44,6 +44,7 @@ DELETE 24
query T
SELECT strings['not_found'] FROM data LIMIT 1;
----
+NULL
statement ok
drop table data;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]