This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 93af4401d7 Minor: Support `nulls` in  `array_replace`, avoid a copy 
(#8054)
93af4401d7 is described below

commit 93af4401d77e9761ca3d187cdc56aa245f7aa7aa
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Nov 9 12:45:16 2023 -0500

    Minor: Support `nulls` in  `array_replace`, avoid a copy (#8054)
    
    * Minor: clean up array_replace
    
    * null test
    
    * remove println
    
    * Fix doc test
    
    * port test to sqllogictest
    
    * Use not_distinct
    
    * Apply suggestions from code review
    
    Co-authored-by: jakevin <[email protected]>
    
    ---------
    
    Co-authored-by: jakevin <[email protected]>
---
 datafusion/physical-expr/src/array_expressions.rs | 151 +++++++++++++---------
 datafusion/sqllogictest/test_files/array.slt      |  31 +++++
 2 files changed, 119 insertions(+), 63 deletions(-)

diff --git a/datafusion/physical-expr/src/array_expressions.rs 
b/datafusion/physical-expr/src/array_expressions.rs
index 64550aabf4..deb4372baa 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -1211,119 +1211,144 @@ array_removement_function!(
     "Array_remove_all SQL function"
 );
 
-fn general_replace(args: &[ArrayRef], arr_n: Vec<i64>) -> Result<ArrayRef> {
-    let list_array = as_list_array(&args[0])?;
-    let from_array = &args[1];
-    let to_array = &args[2];
-
+/// For each row `i` of `list_array`, replaces up to `arr_n[i]` occurrences
+/// of `from_array[i]` with `to_array[i]`, scanning the row left to right.
+///
+/// The type of each **element** in `list_array` must be the same as the type of
+/// `from_array` and `to_array`. This function also handles nested arrays
+/// ([`ListArray`] of [`ListArray`]s).
+///
+/// For non-list element types, matching uses NULL-aware equality
+/// (`not_distinct`), so a NULL `from_array[i]` matches NULL elements. A NULL
+/// row of `list_array` produces a NULL row in the output.
+///
+/// For example, when replacing in a list array whose rows are lists of int32s,
+/// the second and third arguments are int32 arrays, and the fourth argument is
+/// the number of occurrences to replace:
+///
+/// ```text
+/// general_replace(
+///   [1, 2, 3, 2], 2, 10, 1    ==> [1, 10, 3, 2]   (only the first 2 is replaced)
+///   [4, 5, 6, 5], 5, 20, 2    ==> [4, 20, 6, 20]  (both 5s are replaced)
+/// )
+/// ```
+fn general_replace(
+    list_array: &ListArray,
+    from_array: &ArrayRef,
+    to_array: &ArrayRef,
+    arr_n: Vec<i64>,
+) -> Result<ArrayRef> {
+    // Build up the offsets for the final output array
     let mut offsets: Vec<i32> = vec![0];
     let data_type = list_array.value_type();
-    let mut values = new_empty_array(&data_type);
+    // One replaced values array per non-NULL row; concatenated once at the end
+    let mut new_values = vec![];

-    for (row_index, (arr, n)) in 
list_array.iter().zip(arr_n.iter()).enumerate() {
+    // n is the maximum number of matches to replace in this row
+    for (row_index, (list_array_row, n)) in
+        list_array.iter().zip(arr_n.iter()).enumerate()
+    {
         let last_offset: i32 = offsets
             .last()
             .copied()
             .ok_or_else(|| internal_datafusion_err!("offsets should not be 
empty"))?;
-        match arr {
-            Some(arr) => {
-                let indices = UInt32Array::from(vec![row_index as u32]);
-                let from_arr = arrow::compute::take(from_array, &indices, 
None)?;

-                let eq_array = match from_arr.data_type() {
-                    // arrow_ord::cmp_eq does not support ListArray, so we 
need to compare it by loop
+        match list_array_row {
+            Some(list_array_row) => {
+                let indices = UInt32Array::from(vec![row_index as u32]);
+                let from_array_row = arrow::compute::take(from_array, 
&indices, None)?;
+                // Compute all positions in `list_array_row` (which is itself an
+                // array) whose value equals the single value in `from_array_row`
+                let eq_array = match from_array_row.data_type() {
+                    // arrow_ord::cmp::eq does not support ListArray, so we need to
+                    // compare it by loop
                     DataType::List(_) => {
-                        let from_a = as_list_array(&from_arr)?.value(0);
-                        let list_arr = as_list_array(&arr)?;
+                        // `from_array_row` holds one list; extract it so each
+                        // element of this row can be compared against it
+                        let from_array_row_inner =
+                            as_list_array(&from_array_row)?.value(0);
+                        let list_array_row_inner = 
as_list_array(&list_array_row)?;

-                        let mut bool_values = vec![];
-                        for arr in list_arr.iter() {
-                            if let Some(a) = arr {
-                                bool_values.push(Some(a.eq(&from_a)));
-                            } else {
-                                return internal_err!(
-                                    "Null value is not supported in 
array_replace"
-                                );
-                            }
-                        }
-                        BooleanArray::from(bool_values)
+                        // NOTE(review): a NULL element of this row maps to `None`
+                        // and is therefore never replaced, unlike the `not_distinct`
+                        // branch below where NULL matches NULL. Also, `value(0)` is
+                        // taken without checking whether `from_array_row` is NULL;
+                        // confirm the intended semantics for a NULL `from` list.
+                        list_array_row_inner
+                            .iter()
+                            // compare each element of the current row of `list_array`
+                            .map(|row| row.map(|row| 
row.eq(&from_array_row_inner)))
+                            .collect::<BooleanArray>()
                     }
                     _ => {
-                        let from_arr = Scalar::new(from_arr);
-                        arrow_ord::cmp::eq(&arr, &from_arr)?
+                        let from_arr = Scalar::new(from_array_row);
+                        // use not_distinct so NULL = NULL
+                        arrow_ord::cmp::not_distinct(&list_array_row, 
&from_arr)?
                     }
                 };

                 // Use MutableArrayData to build the replaced array
+                let original_data = list_array_row.to_data();
+                let to_data = to_array.to_data();
+                let capacity = Capacities::Array(original_data.len() + 
to_data.len());
+
                 // First array is the original array, second array is the element to replace with.
-                let arrays = vec![arr, to_array.clone()];
-                let arrays_data = arrays
-                    .iter()
-                    .map(|a| a.to_data())
-                    .collect::<Vec<ArrayData>>();
-                let arrays_data = 
arrays_data.iter().collect::<Vec<&ArrayData>>();
-
-                let arrays = arrays
-                    .iter()
-                    .map(|arr| arr.as_ref())
-                    .collect::<Vec<&dyn Array>>();
-                let capacity = Capacities::Array(arrays.iter().map(|a| 
a.len()).sum());
-
-                let mut mutable =
-                    MutableArrayData::with_capacities(arrays_data, false, 
capacity);
+                let mut mutable = MutableArrayData::with_capacities(
+                    vec![&original_data, &to_data],
+                    false,
+                    capacity,
+                );
+                let original_idx = 0;
+                let replace_idx = 1;

+                // NOTE(review): `counter` is incremented before it is compared
+                // with `*n`, so for `n <= 0` the `counter == *n` check never
+                // fires and every match in the row is replaced. Confirm whether
+                // a non-positive `n` should leave the row unchanged instead.
                 let mut counter = 0;
                 for (i, to_replace) in eq_array.iter().enumerate() {
-                    if let Some(to_replace) = to_replace {
-                        if to_replace {
-                            mutable.extend(1, row_index, row_index + 1);
-                            counter += 1;
-                            if counter == *n {
-                                // extend the rest of the array
-                                mutable.extend(0, i + 1, eq_array.len());
-                                break;
-                            }
-                        } else {
-                            mutable.extend(0, i, i + 1);
+                    if let Some(true) = to_replace {
+                        // take this row's replacement value from `to_array`
+                        mutable.extend(replace_idx, row_index, row_index + 1);
+                        counter += 1;
+                        if counter == *n {
+                            // copy the rest of the row unchanged, including matches past n
+                            mutable.extend(original_idx, i + 1, 
eq_array.len());
+                            break;
                         }
                     } else {
-                        return internal_err!("eq_array should not contain 
None");
+                        // copy original data for false / null matches
+                        mutable.extend(original_idx, i, i + 1);
                     }
                 }

                 let data = mutable.freeze();
                 let replaced_array = arrow_array::make_array(data);

-                let v = arrow::compute::concat(&[&values, &replaced_array])?;
-                values = v;
                 offsets.push(last_offset + replaced_array.len() as i32);
+                new_values.push(replaced_array);
             }
             None => {
+                // NULL row: repeat the previous offset so the row is empty in
+                // `values`; it is marked NULL via `list_array.nulls()` below
                 offsets.push(last_offset);
             }
         }
     }

+    // No non-NULL rows (or no rows at all): nothing to concatenate
+    let values = if new_values.is_empty() {
+        new_empty_array(&data_type)
+    } else {
+        let new_values: Vec<_> = new_values.iter().map(|a| 
a.as_ref()).collect();
+        arrow::compute::concat(&new_values)?
+    };
+
     Ok(Arc::new(ListArray::try_new(
         Arc::new(Field::new("item", data_type, true)),
         OffsetBuffer::new(offsets.into()),
         values,
-        None,
+        list_array.nulls().cloned(),
     )?))
 }
 
+/// Replaces the first occurrence of `args[1][i]` with `args[2][i]` in row `i`
+/// of the list array `args[0]`.
 pub fn array_replace(args: &[ArrayRef]) -> Result<ArrayRef> {
-    general_replace(args, vec![1; args[0].len()])
+    // replace at most one occurrence in each row
+    let arr_n = vec![1; args[0].len()];
+    general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
 }
 
+/// Replaces up to `args[3][i]` occurrences of `args[1][i]` with `args[2][i]`
+/// in row `i` of the list array `args[0]`.
 pub fn array_replace_n(args: &[ArrayRef]) -> Result<ArrayRef> {
-    let arr = as_int64_array(&args[3])?;
-    let arr_n = arr.values().to_vec();
-    general_replace(args, arr_n)
+    // replace the specified number of occurrences
+    // NOTE(review): `values()` is used without checking the null bitmap, so a
+    // NULL count presumably uses whatever value is stored in that slot;
+    // confirm the intended result for a NULL `n`.
+    let arr_n = as_int64_array(&args[3])?.values().to_vec();
+    general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
 }
 
+/// Replaces every occurrence of `args[1][i]` with `args[2][i]` in row `i` of
+/// the list array `args[0]`.
 pub fn array_replace_all(args: &[ArrayRef]) -> Result<ArrayRef> {
-    general_replace(args, vec![i64::MAX; args[0].len()])
+    // replace all occurrences (up to "i64::MAX")
+    let arr_n = vec![i64::MAX; args[0].len()];
+    general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
 }
 
 macro_rules! to_string {
diff --git a/datafusion/sqllogictest/test_files/array.slt 
b/datafusion/sqllogictest/test_files/array.slt
index 85218efb5e..c57369c167 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -1720,6 +1720,37 @@ select array_replace_all(make_array([1, 2, 3], [4, 5, 
6], [4, 5, 6], [10, 11, 12
 [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], 
[10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23, 
24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], 
[25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 
13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11, 
12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]]
 [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], 
[10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 
24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], 
[34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 
13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11, 
12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]]
 
+# array_replace with null handling
+
+statement ok
+create table t as values
+  (make_array(3, 1, NULL, 3), 3,    4,    2),
+  (make_array(3, 1, NULL, 3), NULL, 5,    2),
+  (NULL,                            3,    2,    1),
+  (make_array(3, 1, 3),             3,    NULL, 1)
+;
+
+
+# ([3, 1, NULL, 3], 3,    4,    2)  => [4, 1, NULL, 4] NULL element is not matched
+# ([3, 1, NULL, 3], NULL, 5,    2)  => [3, 1, 5, 3]    NULL element is replaced with 5
+# (NULL,            3,    2,    1)  => NULL
+# ([3, 1, 3],       3,    NULL, 1)  => [NULL, 1, 3]
+
+query ?III?
+select column1, column2, column3, column4, array_replace_n(column1, column2, 
column3, column4) from t;
+----
+[3, 1, , 3] 3 4 2 [4, 1, , 4]
+[3, 1, , 3] NULL 5 2 [3, 1, 5, 3]
+NULL 3 2 1 NULL
+[3, 1, 3] 3 NULL 1 [, 1, 3]
+
+
+
+statement ok
+drop table t;
+
+
+
 ## array_to_string (aliases: `list_to_string`, `array_join`, `list_join`)
 
 # array_to_string scalar function #1

Reply via email to