This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new f5d10e55d5 Rewrite `array_ndims` to fix List(Null) handling (#8320)
f5d10e55d5 is described below

commit f5d10e55d575e1eec58b993cab2d8a7ca2370ff9
Author: Jay Zhan <[email protected]>
AuthorDate: Sat Dec 2 06:28:32 2023 +0800

    Rewrite `array_ndims` to fix List(Null) handling (#8320)
    
    * done
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add more test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/common/src/utils.rs                    | 32 ++++++++++
 datafusion/physical-expr/src/array_expressions.rs | 76 ++++++++---------------
 datafusion/sqllogictest/test_files/array.slt      | 42 +++++++++++--
 3 files changed, 97 insertions(+), 53 deletions(-)

diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs
index 12d4f516b4..7f2dc61c07 100644
--- a/datafusion/common/src/utils.rs
+++ b/datafusion/common/src/utils.rs
@@ -26,6 +26,7 @@ use arrow::compute::{partition, SortColumn, SortOptions};
 use arrow::datatypes::{Field, SchemaRef, UInt32Type};
 use arrow::record_batch::RecordBatch;
 use arrow_array::{Array, LargeListArray, ListArray};
+use arrow_schema::DataType;
 use sqlparser::ast::Ident;
 use sqlparser::dialect::GenericDialect;
 use sqlparser::parser::Parser;
@@ -402,6 +403,37 @@ pub fn arrays_into_list_array(
     ))
 }
 
+/// Get the base type of a data type.
+///
+/// Example
+/// ```
+/// use arrow::datatypes::{DataType, Field};
+/// use datafusion_common::utils::base_type;
+/// use std::sync::Arc;
+///
+/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
+/// assert_eq!(base_type(&data_type), DataType::Int32);
+///
+/// let data_type = DataType::Int32;
+/// assert_eq!(base_type(&data_type), DataType::Int32);
+/// ```
+pub fn base_type(data_type: &DataType) -> DataType {
+    if let DataType::List(field) = data_type {
+        base_type(field.data_type())
+    } else {
+        data_type.to_owned()
+    }
+}
+
+/// Compute the number of dimensions in a list data type.
+pub fn list_ndims(data_type: &DataType) -> u64 {
+    if let DataType::List(field) = data_type {
+        1 + list_ndims(field.data_type())
+    } else {
+        0
+    }
+}
+
 /// An extension trait for smart pointers. Provides an interface to get a
 /// raw pointer to the data (with metadata stripped away).
 ///
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index a36f485d7b..7059c6a9f3 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -33,7 +33,7 @@ use datafusion_common::cast::{
     as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array,
     as_null_array, as_string_array,
 };
-use datafusion_common::utils::array_into_list_array;
+use datafusion_common::utils::{array_into_list_array, list_ndims};
 use datafusion_common::{
     exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
     DataFusionError, Result,
@@ -103,6 +103,7 @@ fn compare_element_to_list(
 ) -> Result<BooleanArray> {
     let indices = UInt32Array::from(vec![row_index as u32]);
     let element_array_row = arrow::compute::take(element_array, &indices, None)?;
+
     // Compute all positions in list_row_array (that is itself an
     // array) that are equal to `from_array_row`
     let res = match element_array_row.data_type() {
@@ -176,35 +177,6 @@ fn compute_array_length(
     }
 }
 
-/// Returns the dimension of the array
-fn compute_array_ndims(arr: Option<ArrayRef>) -> Result<Option<u64>> {
-    Ok(compute_array_ndims_with_datatype(arr)?.0)
-}
-
-/// Returns the dimension and the datatype of elements of the array
-fn compute_array_ndims_with_datatype(
-    arr: Option<ArrayRef>,
-) -> Result<(Option<u64>, DataType)> {
-    let mut res: u64 = 1;
-    let mut value = match arr {
-        Some(arr) => arr,
-        None => return Ok((None, DataType::Null)),
-    };
-    if value.is_empty() {
-        return Ok((None, DataType::Null));
-    }
-
-    loop {
-        match value.data_type() {
-            DataType::List(..) => {
-                value = downcast_arg!(value, ListArray).value(0);
-                res += 1;
-            }
-            data_type => return Ok((Some(res), data_type.clone())),
-        }
-    }
-}
-
 /// Returns the length of each array dimension
 fn compute_array_dims(arr: Option<ArrayRef>) -> Result<Option<Vec<Option<u64>>>> {
     let mut value = match arr {
@@ -825,10 +797,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
 fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
     let args_ndim = args
         .iter()
-        .map(|arg| compute_array_ndims(Some(arg.to_owned())))
-        .collect::<Result<Vec<_>>>()?
-        .into_iter()
-        .map(|x| x.unwrap_or(0))
+        .map(|arg| datafusion_common::utils::list_ndims(arg.data_type()))
         .collect::<Vec<_>>();
     let max_ndim = args_ndim.iter().max().unwrap_or(&0);
 
@@ -919,6 +888,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
         Arc::new(compute::concat(elements.as_slice())?),
         Some(NullBuffer::new(buffer)),
     );
+
     Ok(Arc::new(list_arr))
 }
 
@@ -926,11 +896,11 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
 pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
     let mut new_args = vec![];
     for arg in args {
-        let (ndim, lower_data_type) =
-            compute_array_ndims_with_datatype(Some(arg.clone()))?;
-        if ndim.is_none() || ndim == Some(1) {
-            return not_impl_err!("Array is not type '{lower_data_type:?}'.");
-        } else if !lower_data_type.equals_datatype(&DataType::Null) {
+        let ndim = list_ndims(arg.data_type());
+        let base_type = datafusion_common::utils::base_type(arg.data_type());
+        if ndim == 0 {
+            return not_impl_err!("Array is not type '{base_type:?}'.");
+        } else if !base_type.eq(&DataType::Null) {
             new_args.push(arg.clone());
         }
     }
@@ -1765,14 +1735,22 @@ pub fn array_dims(args: &[ArrayRef]) -> Result<ArrayRef> {
 
 /// Array_ndims SQL function
 pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
-    let list_array = as_list_array(&args[0])?;
+    if let Some(list_array) = args[0].as_list_opt::<i32>() {
+        let ndims = datafusion_common::utils::list_ndims(list_array.data_type());
 
-    let result = list_array
-        .iter()
-        .map(compute_array_ndims)
-        .collect::<Result<UInt64Array>>()?;
+        let mut data = vec![];
+        for arr in list_array.iter() {
+            if arr.is_some() {
+                data.push(Some(ndims))
+            } else {
+                data.push(None)
+            }
+        }
 
-    Ok(Arc::new(result) as ArrayRef)
+        Ok(Arc::new(UInt64Array::from(data)) as ArrayRef)
+    } else {
+        Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef)
+    }
 }
 
 /// Array_has SQL function
@@ -2034,10 +2012,10 @@ mod tests {
                 .unwrap();
 
         let expected = as_list_array(&array2d_1).unwrap();
-        let expected_dim = compute_array_ndims(Some(array2d_1.to_owned())).unwrap();
+        let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type());
         assert_ne!(as_list_array(&res[0]).unwrap(), expected);
         assert_eq!(
-            compute_array_ndims(Some(res[0].clone())).unwrap(),
+            datafusion_common::utils::list_ndims(res[0].data_type()),
             expected_dim
         );
 
@@ -2047,10 +2025,10 @@ mod tests {
             align_array_dimensions(vec![array1d_1, Arc::new(array3d_2.clone())]).unwrap();
 
         let expected = as_list_array(&array3d_1).unwrap();
-        let expected_dim = compute_array_ndims(Some(array3d_1.to_owned())).unwrap();
+        let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type());
         assert_ne!(as_list_array(&res[0]).unwrap(), expected);
         assert_eq!(
-            compute_array_ndims(Some(res[0].clone())).unwrap(),
+            datafusion_common::utils::list_ndims(res[0].data_type()),
             expected_dim
         );
     }
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index 3b45d995e1..092bc697a1 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -2479,10 +2479,44 @@ NULL [3] [4]
 ## array_ndims (aliases: `list_ndims`)
 
 # array_ndims scalar function #1
+
 query III
-select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])), array_ndims(make_array([[[[1], [2]]]]));
+select 
+  array_ndims(1),
+  array_ndims(null),
+  array_ndims([2, 3]);
 ----
-1 2 5
+0 0 1
+
+statement ok
+CREATE TABLE array_ndims_table
+AS VALUES
+  (1, [1, 2, 3], [[7]], [[[[[10]]]]]),
+  (2, [4, 5], [[8]], [[[[[10]]]]]),
+  (null, [6], [[9]], [[[[[10]]]]]),
+  (3, [6], [[9]], [[[[[10]]]]])
+;
+
+query IIII
+select 
+  array_ndims(column1),
+  array_ndims(column2),
+  array_ndims(column3),
+  array_ndims(column4)
+from array_ndims_table;
+----
+0 1 2 5
+0 1 2 5
+0 1 2 5
+0 1 2 5
+
+statement ok
+drop table array_ndims_table;
+
+query I
+select array_ndims(arrow_cast([null], 'List(List(List(Int64)))'));
+----
+3
 
 # array_ndims scalar function #2
 query II
@@ -2494,7 +2528,7 @@ select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_
 query II
 select array_ndims(make_array()), array_ndims(make_array(make_array()))
 ----
-NULL 2
+1 2
 
 # list_ndims scalar function #4 (function alias `array_ndims`)
 query III
@@ -2505,7 +2539,7 @@ select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])),
 query II
 select array_ndims(make_array()), array_ndims(make_array(make_array()))
 ----
-NULL 2
+1 2
 
 # array_ndims with columns
 query III

Reply via email to