This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 06ee8a4439 Issue-9765 - Extract array_resize and cardinality functions
from functions-array subcrate's kernels and udf containers (#9766)
06ee8a4439 is described below
commit 06ee8a443911797d97fc216a4fc8a29b7604dca8
Author: Eren Avsarogullari <[email protected]>
AuthorDate: Sun Mar 24 00:01:44 2024 -0700
Issue-9765 - Extract array_resize and cardinality functions from
functions-array subcrate's kernels and udf containers (#9766)
---
datafusion/functions-array/src/cardinality.rs | 115 +++++++++++++++++
datafusion/functions-array/src/kernels.rs | 165 +-----------------------
datafusion/functions-array/src/lib.rs | 10 +-
datafusion/functions-array/src/resize.rs | 179 ++++++++++++++++++++++++++
datafusion/functions-array/src/udf.rs | 110 ----------------
datafusion/functions-array/src/utils.rs | 49 +++++--
6 files changed, 343 insertions(+), 285 deletions(-)
diff --git a/datafusion/functions-array/src/cardinality.rs
b/datafusion/functions-array/src/cardinality.rs
new file mode 100644
index 0000000000..483336fe08
--- /dev/null
+++ b/datafusion/functions-array/src/cardinality.rs
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`ScalarUDFImpl`] definitions for cardinality function.
+
+use crate::utils::make_scalar_function;
+use arrow_array::{ArrayRef, GenericListArray, OffsetSizeTrait, UInt64Array};
+use arrow_schema::DataType;
+use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64};
+use datafusion_common::cast::{as_large_list_array, as_list_array};
+use datafusion_common::{exec_err, plan_err};
+use datafusion_expr::expr::ScalarFunction;
+use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature,
Volatility};
+use std::any::Any;
+use std::sync::Arc;
+
+make_udf_function!(
+ Cardinality,
+ cardinality,
+ array,
+ "returns the total number of elements in the array.",
+ cardinality_udf
+);
+
+impl Cardinality {
+ pub fn new() -> Self {
+ Self {
+ signature: Signature::array(Volatility::Immutable),
+ aliases: vec![String::from("cardinality")],
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(super) struct Cardinality {
+ signature: Signature,
+ aliases: Vec<String>,
+}
+impl ScalarUDFImpl for Cardinality {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn name(&self) -> &str {
+ "cardinality"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) ->
datafusion_common::Result<DataType> {
+ Ok(match arg_types[0] {
+ List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64,
+ _ => {
+ return plan_err!("The cardinality function can only accept
List/LargeList/FixedSizeList.");
+ }
+ })
+ }
+
+ fn invoke(&self, args: &[ColumnarValue]) ->
datafusion_common::Result<ColumnarValue> {
+ make_scalar_function(cardinality_inner)(args)
+ }
+
+ fn aliases(&self) -> &[String] {
+ &self.aliases
+ }
+}
+
+/// Cardinality SQL function
+pub fn cardinality_inner(args: &[ArrayRef]) ->
datafusion_common::Result<ArrayRef> {
+ if args.len() != 1 {
+ return exec_err!("cardinality expects one argument");
+ }
+
+ match &args[0].data_type() {
+ List(_) => {
+ let list_array = as_list_array(&args[0])?;
+ generic_list_cardinality::<i32>(list_array)
+ }
+ LargeList(_) => {
+ let list_array = as_large_list_array(&args[0])?;
+ generic_list_cardinality::<i64>(list_array)
+ }
+ other => {
+ exec_err!("cardinality does not support type '{:?}'", other)
+ }
+ }
+}
+
+fn generic_list_cardinality<O: OffsetSizeTrait>(
+ array: &GenericListArray<O>,
+) -> datafusion_common::Result<ArrayRef> {
+ let result = array
+ .iter()
+ .map(|arr| match crate::utils::compute_array_dims(arr)? {
+ Some(vector) => Ok(Some(vector.iter().map(|x|
x.unwrap()).product::<u64>())),
+ None => Ok(None),
+ })
+ .collect::<datafusion_common::Result<UInt64Array>>()?;
+ Ok(Arc::new(result) as ArrayRef)
+}
diff --git a/datafusion/functions-array/src/kernels.rs
b/datafusion/functions-array/src/kernels.rs
index 4745db0170..1a08b64197 100644
--- a/datafusion/functions-array/src/kernels.rs
+++ b/datafusion/functions-array/src/kernels.rs
@@ -18,80 +18,19 @@
//! implementation kernels for array functions
use arrow::array::{
- Array, ArrayRef, Capacities, GenericListArray, Int64Array, ListArray,
- MutableArrayData, OffsetSizeTrait, UInt64Array,
+ Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array,
};
use arrow::datatypes::{DataType, UInt64Type};
-use arrow_buffer::{ArrowNativeType, OffsetBuffer};
-use arrow_schema::FieldRef;
+use arrow_buffer::OffsetBuffer;
use datafusion_common::cast::{
- as_generic_list_array, as_int64_array, as_large_list_array, as_list_array,
-};
-use datafusion_common::{
- exec_err, internal_datafusion_err, DataFusionError, Result, ScalarValue,
+ as_generic_list_array, as_large_list_array, as_list_array,
};
+use datafusion_common::{exec_err, Result};
-use crate::utils::downcast_arg;
-use std::any::type_name;
+use crate::utils::compute_array_dims;
use std::sync::Arc;
-/// Returns the length of each array dimension
-fn compute_array_dims(arr: Option<ArrayRef>) ->
Result<Option<Vec<Option<u64>>>> {
- let mut value = match arr {
- Some(arr) => arr,
- None => return Ok(None),
- };
- if value.is_empty() {
- return Ok(None);
- }
- let mut res = vec![Some(value.len() as u64)];
-
- loop {
- match value.data_type() {
- DataType::List(..) => {
- value = downcast_arg!(value, ListArray).value(0);
- res.push(Some(value.len() as u64));
- }
- _ => return Ok(Some(res)),
- }
- }
-}
-
-fn generic_list_cardinality<O: OffsetSizeTrait>(
- array: &GenericListArray<O>,
-) -> Result<ArrayRef> {
- let result = array
- .iter()
- .map(|arr| match compute_array_dims(arr)? {
- Some(vector) => Ok(Some(vector.iter().map(|x|
x.unwrap()).product::<u64>())),
- None => Ok(None),
- })
- .collect::<Result<UInt64Array>>()?;
- Ok(Arc::new(result) as ArrayRef)
-}
-
-/// Cardinality SQL function
-pub fn cardinality(args: &[ArrayRef]) -> Result<ArrayRef> {
- if args.len() != 1 {
- return exec_err!("cardinality expects one argument");
- }
-
- match &args[0].data_type() {
- DataType::List(_) => {
- let list_array = as_list_array(&args[0])?;
- generic_list_cardinality::<i32>(list_array)
- }
- DataType::LargeList(_) => {
- let list_array = as_large_list_array(&args[0])?;
- generic_list_cardinality::<i64>(list_array)
- }
- other => {
- exec_err!("cardinality does not support type '{:?}'", other)
- }
- }
-}
-
/// Array_dims SQL function
pub fn array_dims(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
@@ -158,100 +97,6 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}
-/// array_resize SQL function
-pub fn array_resize(arg: &[ArrayRef]) -> Result<ArrayRef> {
- if arg.len() < 2 || arg.len() > 3 {
- return exec_err!("array_resize needs two or three arguments");
- }
-
- let new_len = as_int64_array(&arg[1])?;
- let new_element = if arg.len() == 3 {
- Some(arg[2].clone())
- } else {
- None
- };
-
- match &arg[0].data_type() {
- DataType::List(field) => {
- let array = as_list_array(&arg[0])?;
- general_list_resize::<i32>(array, new_len, field, new_element)
- }
- DataType::LargeList(field) => {
- let array = as_large_list_array(&arg[0])?;
- general_list_resize::<i64>(array, new_len, field, new_element)
- }
- array_type => exec_err!("array_resize does not support type
'{array_type:?}'."),
- }
-}
-
-/// array_resize keep the original array and append the default element to the
end
-fn general_list_resize<O: OffsetSizeTrait>(
- array: &GenericListArray<O>,
- count_array: &Int64Array,
- field: &FieldRef,
- default_element: Option<ArrayRef>,
-) -> Result<ArrayRef>
-where
- O: TryInto<i64>,
-{
- let data_type = array.value_type();
-
- let values = array.values();
- let original_data = values.to_data();
-
- // create default element array
- let default_element = if let Some(default_element) = default_element {
- default_element
- } else {
- let null_scalar = ScalarValue::try_from(&data_type)?;
- null_scalar.to_array_of_size(original_data.len())?
- };
- let default_value_data = default_element.to_data();
-
- // create a mutable array to store the original data
- let capacity = Capacities::Array(original_data.len() +
default_value_data.len());
- let mut offsets = vec![O::usize_as(0)];
- let mut mutable = MutableArrayData::with_capacities(
- vec![&original_data, &default_value_data],
- false,
- capacity,
- );
-
- for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
- let count = count_array.value(row_index).to_usize().ok_or_else(|| {
- internal_datafusion_err!("array_resize: failed to convert size to
usize")
- })?;
- let count = O::usize_as(count);
- let start = offset_window[0];
- if start + count > offset_window[1] {
- let extra_count =
- (start + count - offset_window[1]).try_into().map_err(|_| {
- internal_datafusion_err!(
- "array_resize: failed to convert size to i64"
- )
- })?;
- let end = offset_window[1];
- mutable.extend(0, (start).to_usize().unwrap(),
(end).to_usize().unwrap());
- // append default element
- for _ in 0..extra_count {
- mutable.extend(1, row_index, row_index + 1);
- }
- } else {
- let end = start + count;
- mutable.extend(0, (start).to_usize().unwrap(),
(end).to_usize().unwrap());
- };
- offsets.push(offsets[row_index] + count);
- }
-
- let data = mutable.freeze();
- Ok(Arc::new(GenericListArray::<O>::try_new(
- field.clone(),
- OffsetBuffer::<O>::new(offsets.into()),
- arrow_array::make_array(data),
- None,
- )?))
-}
-
// Create new offsets that are euqiavlent to `flatten` the array.
fn get_offsets_for_flatten<O: OffsetSizeTrait>(
offsets: OffsetBuffer<O>,
diff --git a/datafusion/functions-array/src/lib.rs
b/datafusion/functions-array/src/lib.rs
index 4a7bb3fda9..feecd18c2e 100644
--- a/datafusion/functions-array/src/lib.rs
+++ b/datafusion/functions-array/src/lib.rs
@@ -29,6 +29,7 @@
pub mod macros;
mod array_has;
+mod cardinality;
mod concat;
mod core;
mod empty;
@@ -41,6 +42,7 @@ mod range;
mod remove;
mod repeat;
mod replace;
+mod resize;
mod reverse;
mod rewrite;
mod set_ops;
@@ -60,6 +62,7 @@ pub mod expr_fn {
pub use super::array_has::array_has;
pub use super::array_has::array_has_all;
pub use super::array_has::array_has_any;
+ pub use super::cardinality::cardinality;
pub use super::concat::array_append;
pub use super::concat::array_concat;
pub use super::concat::array_prepend;
@@ -82,6 +85,7 @@ pub mod expr_fn {
pub use super::replace::array_replace;
pub use super::replace::array_replace_all;
pub use super::replace::array_replace_n;
+ pub use super::resize::array_resize;
pub use super::reverse::array_reverse;
pub use super::set_ops::array_distinct;
pub use super::set_ops::array_intersect;
@@ -91,8 +95,6 @@ pub mod expr_fn {
pub use super::string::string_to_array;
pub use super::udf::array_dims;
pub use super::udf::array_ndims;
- pub use super::udf::array_resize;
- pub use super::udf::cardinality;
pub use super::udf::flatten;
}
@@ -104,7 +106,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) ->
Result<()> {
range::range_udf(),
range::gen_series_udf(),
udf::array_dims_udf(),
- udf::cardinality_udf(),
+ cardinality::cardinality_udf(),
udf::array_ndims_udf(),
concat::array_append_udf(),
concat::array_prepend_udf(),
@@ -123,7 +125,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) ->
Result<()> {
udf::flatten_udf(),
sort::array_sort_udf(),
repeat::array_repeat_udf(),
- udf::array_resize_udf(),
+ resize::array_resize_udf(),
reverse::array_reverse_udf(),
set_ops::array_distinct_udf(),
set_ops::array_intersect_udf(),
diff --git a/datafusion/functions-array/src/resize.rs
b/datafusion/functions-array/src/resize.rs
new file mode 100644
index 0000000000..f3996110f9
--- /dev/null
+++ b/datafusion/functions-array/src/resize.rs
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`ScalarUDFImpl`] definitions for array_resize function.
+
+use crate::utils::make_scalar_function;
+use arrow::array::{Capacities, MutableArrayData};
+use arrow_array::{ArrayRef, GenericListArray, Int64Array, OffsetSizeTrait};
+use arrow_buffer::{ArrowNativeType, OffsetBuffer};
+use arrow_schema::DataType::{FixedSizeList, LargeList, List};
+use arrow_schema::{DataType, FieldRef};
+use datafusion_common::cast::{as_int64_array, as_large_list_array,
as_list_array};
+use datafusion_common::{exec_err, internal_datafusion_err, ScalarValue};
+use datafusion_expr::expr::ScalarFunction;
+use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature,
Volatility};
+use std::any::Any;
+use std::sync::Arc;
+
+make_udf_function!(
+ ArrayResize,
+ array_resize,
+ array size value,
+ "returns an array with the specified size filled with the given value.",
+ array_resize_udf
+);
+
+#[derive(Debug)]
+pub(super) struct ArrayResize {
+ signature: Signature,
+ aliases: Vec<String>,
+}
+
+impl ArrayResize {
+ pub fn new() -> Self {
+ Self {
+ signature: Signature::variadic_any(Volatility::Immutable),
+ aliases: vec!["array_resize".to_string(),
"list_resize".to_string()],
+ }
+ }
+}
+
+impl ScalarUDFImpl for ArrayResize {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn name(&self) -> &str {
+ "array_resize"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) ->
datafusion_common::Result<DataType> {
+ match &arg_types[0] {
+ List(field) | FixedSizeList(field, _) => Ok(List(field.clone())),
+ LargeList(field) => Ok(LargeList(field.clone())),
+ _ => exec_err!(
+ "Not reachable, data_type should be List, LargeList or
FixedSizeList"
+ ),
+ }
+ }
+
+ fn invoke(&self, args: &[ColumnarValue]) ->
datafusion_common::Result<ColumnarValue> {
+ make_scalar_function(array_resize_inner)(args)
+ }
+
+ fn aliases(&self) -> &[String] {
+ &self.aliases
+ }
+}
+
+/// array_resize SQL function
+pub fn array_resize_inner(arg: &[ArrayRef]) ->
datafusion_common::Result<ArrayRef> {
+ if arg.len() < 2 || arg.len() > 3 {
+ return exec_err!("array_resize needs two or three arguments");
+ }
+
+ let new_len = as_int64_array(&arg[1])?;
+ let new_element = if arg.len() == 3 {
+ Some(arg[2].clone())
+ } else {
+ None
+ };
+
+ match &arg[0].data_type() {
+ DataType::List(field) => {
+ let array = as_list_array(&arg[0])?;
+ general_list_resize::<i32>(array, new_len, field, new_element)
+ }
+ DataType::LargeList(field) => {
+ let array = as_large_list_array(&arg[0])?;
+ general_list_resize::<i64>(array, new_len, field, new_element)
+ }
+ array_type => exec_err!("array_resize does not support type
'{array_type:?}'."),
+ }
+}
+
+/// array_resize keep the original array and append the default element to the
end
+fn general_list_resize<O: OffsetSizeTrait>(
+ array: &GenericListArray<O>,
+ count_array: &Int64Array,
+ field: &FieldRef,
+ default_element: Option<ArrayRef>,
+) -> datafusion_common::Result<ArrayRef>
+where
+ O: TryInto<i64>,
+{
+ let data_type = array.value_type();
+
+ let values = array.values();
+ let original_data = values.to_data();
+
+ // create default element array
+ let default_element = if let Some(default_element) = default_element {
+ default_element
+ } else {
+ let null_scalar = ScalarValue::try_from(&data_type)?;
+ null_scalar.to_array_of_size(original_data.len())?
+ };
+ let default_value_data = default_element.to_data();
+
+ // create a mutable array to store the original data
+ let capacity = Capacities::Array(original_data.len() +
default_value_data.len());
+ let mut offsets = vec![O::usize_as(0)];
+ let mut mutable = MutableArrayData::with_capacities(
+ vec![&original_data, &default_value_data],
+ false,
+ capacity,
+ );
+
+ for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
+ let count = count_array.value(row_index).to_usize().ok_or_else(|| {
+ internal_datafusion_err!("array_resize: failed to convert size to
usize")
+ })?;
+ let count = O::usize_as(count);
+ let start = offset_window[0];
+ if start + count > offset_window[1] {
+ let extra_count =
+ (start + count - offset_window[1]).try_into().map_err(|_| {
+ internal_datafusion_err!(
+ "array_resize: failed to convert size to i64"
+ )
+ })?;
+ let end = offset_window[1];
+ mutable.extend(0, (start).to_usize().unwrap(),
(end).to_usize().unwrap());
+ // append default element
+ for _ in 0..extra_count {
+ mutable.extend(1, row_index, row_index + 1);
+ }
+ } else {
+ let end = start + count;
+ mutable.extend(0, (start).to_usize().unwrap(),
(end).to_usize().unwrap());
+ };
+ offsets.push(offsets[row_index] + count);
+ }
+
+ let data = mutable.freeze();
+ Ok(Arc::new(GenericListArray::<O>::try_new(
+ field.clone(),
+ OffsetBuffer::<O>::new(offsets.into()),
+ arrow_array::make_array(data),
+ None,
+ )?))
+}
diff --git a/datafusion/functions-array/src/udf.rs
b/datafusion/functions-array/src/udf.rs
index 9cbcf0a923..bdc11155b6 100644
--- a/datafusion/functions-array/src/udf.rs
+++ b/datafusion/functions-array/src/udf.rs
@@ -85,116 +85,6 @@ impl ScalarUDFImpl for ArrayDims {
}
}
-make_udf_function!(
- ArrayResize,
- array_resize,
- array size value,
- "returns an array with the specified size filled with the given value.",
- array_resize_udf
-);
-
-#[derive(Debug)]
-pub(super) struct ArrayResize {
- signature: Signature,
- aliases: Vec<String>,
-}
-
-impl ArrayResize {
- pub fn new() -> Self {
- Self {
- signature: Signature::variadic_any(Volatility::Immutable),
- aliases: vec!["array_resize".to_string(),
"list_resize".to_string()],
- }
- }
-}
-
-impl ScalarUDFImpl for ArrayResize {
- fn as_any(&self) -> &dyn Any {
- self
- }
- fn name(&self) -> &str {
- "array_resize"
- }
-
- fn signature(&self) -> &Signature {
- &self.signature
- }
-
- fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- use DataType::*;
- match &arg_types[0] {
- List(field) | FixedSizeList(field, _) => Ok(List(field.clone())),
- LargeList(field) => Ok(LargeList(field.clone())),
- _ => exec_err!(
- "Not reachable, data_type should be List, LargeList or
FixedSizeList"
- ),
- }
- }
-
- fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
- let args = ColumnarValue::values_to_arrays(args)?;
- crate::kernels::array_resize(&args).map(ColumnarValue::Array)
- }
-
- fn aliases(&self) -> &[String] {
- &self.aliases
- }
-}
-
-make_udf_function!(
- Cardinality,
- cardinality,
- array,
- "returns the total number of elements in the array.",
- cardinality_udf
-);
-
-impl Cardinality {
- pub fn new() -> Self {
- Self {
- signature: Signature::array(Volatility::Immutable),
- aliases: vec![String::from("cardinality")],
- }
- }
-}
-
-#[derive(Debug)]
-pub(super) struct Cardinality {
- signature: Signature,
- aliases: Vec<String>,
-}
-impl ScalarUDFImpl for Cardinality {
- fn as_any(&self) -> &dyn Any {
- self
- }
- fn name(&self) -> &str {
- "cardinality"
- }
-
- fn signature(&self) -> &Signature {
- &self.signature
- }
-
- fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- use DataType::*;
- Ok(match arg_types[0] {
- List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64,
- _ => {
- return plan_err!("The cardinality function can only accept
List/LargeList/FixedSizeList.");
- }
- })
- }
-
- fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
- let args = ColumnarValue::values_to_arrays(args)?;
- crate::kernels::cardinality(&args).map(ColumnarValue::Array)
- }
-
- fn aliases(&self) -> &[String] {
- &self.aliases
- }
-}
-
make_udf_function!(
ArrayNdims,
array_ndims,
diff --git a/datafusion/functions-array/src/utils.rs
b/datafusion/functions-array/src/utils.rs
index c0f7627d2a..d86e4fe2ab 100644
--- a/datafusion/functions-array/src/utils.rs
+++ b/datafusion/functions-array/src/utils.rs
@@ -22,15 +22,30 @@ use std::sync::Arc;
use arrow::{array::ArrayRef, datatypes::DataType};
use arrow_array::{
- Array, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar,
UInt32Array,
+ Array, BooleanArray, GenericListArray, ListArray, OffsetSizeTrait, Scalar,
+ UInt32Array,
};
use arrow_buffer::OffsetBuffer;
use arrow_schema::Field;
use datafusion_common::cast::{as_large_list_array, as_list_array};
use datafusion_common::{exec_err, plan_err, Result, ScalarValue};
+use core::any::type_name;
+use datafusion_common::DataFusionError;
use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation};
+macro_rules! downcast_arg {
+ ($ARG:expr, $ARRAY_TYPE:ident) => {{
+ $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "could not cast to {}",
+ type_name::<$ARRAY_TYPE>()
+ ))
+ })?
+ }};
+}
+pub(crate) use downcast_arg;
+
pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
let data_type = args[0].data_type();
if !args.iter().all(|arg| {
@@ -214,17 +229,29 @@ pub(crate) fn compare_element_to_list(
Ok(res)
}
-macro_rules! downcast_arg {
- ($ARG:expr, $ARRAY_TYPE:ident) => {{
- $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
- DataFusionError::Internal(format!(
- "could not cast to {}",
- type_name::<$ARRAY_TYPE>()
- ))
- })?
- }};
+/// Returns the length of each array dimension
+pub(crate) fn compute_array_dims(
+ arr: Option<ArrayRef>,
+) -> Result<Option<Vec<Option<u64>>>> {
+ let mut value = match arr {
+ Some(arr) => arr,
+ None => return Ok(None),
+ };
+ if value.is_empty() {
+ return Ok(None);
+ }
+ let mut res = vec![Some(value.len() as u64)];
+
+ loop {
+ match value.data_type() {
+ DataType::List(..) => {
+ value = downcast_arg!(value, ListArray).value(0);
+ res.push(Some(value.len() as u64));
+ }
+ _ => return Ok(Some(res)),
+ }
+ }
}
-pub(crate) use downcast_arg;
#[cfg(test)]
mod tests {