This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 93af4401d7 Minor: Support `nulls` in `array_replace`, avoid a copy
(#8054)
93af4401d7 is described below
commit 93af4401d77e9761ca3d187cdc56aa245f7aa7aa
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Nov 9 12:45:16 2023 -0500
Minor: Support `nulls` in `array_replace`, avoid a copy (#8054)
* Minor: clean up array_replace
* null test
* remove println
* Fix doc test
* port test to sqllogictest
* Use not_distinct
* Apply suggestions from code review
Co-authored-by: jakevin <[email protected]>
---------
Co-authored-by: jakevin <[email protected]>
---
datafusion/physical-expr/src/array_expressions.rs | 151 +++++++++++++---------
datafusion/sqllogictest/test_files/array.slt | 31 +++++
2 files changed, 119 insertions(+), 63 deletions(-)
diff --git a/datafusion/physical-expr/src/array_expressions.rs
b/datafusion/physical-expr/src/array_expressions.rs
index 64550aabf4..deb4372baa 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -1211,119 +1211,144 @@ array_removement_function!(
"Array_remove_all SQL function"
);
-fn general_replace(args: &[ArrayRef], arr_n: Vec<i64>) -> Result<ArrayRef> {
- let list_array = as_list_array(&args[0])?;
- let from_array = &args[1];
- let to_array = &args[2];
-
+/// For each row `i` of `list_array`, replaces up to `arr_n[i]` occurrences
+/// of `from_array[i]` in `list_array[i]` with `to_array[i]`.
+///
+/// The type of each **element** in `list_array` must be the same as the type of
+/// `from_array` and `to_array`. This function also handles nested arrays
+/// ([`ListArray`] of [`ListArray`]s)
+///
+/// For example, when called to replace a list array (where each element is a
+/// list of int32s), the second and third arguments are int32 arrays, and the
+/// fourth argument is the number of occurrences to replace:
+///
+/// ```text
+/// general_replace(
+/// [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced)
+/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced)
+/// )
+/// ```
+fn general_replace(
+ list_array: &ListArray,
+ from_array: &ArrayRef,
+ to_array: &ArrayRef,
+ arr_n: Vec<i64>,
+) -> Result<ArrayRef> {
+ // Build up the offsets for the final output array
let mut offsets: Vec<i32> = vec![0];
let data_type = list_array.value_type();
- let mut values = new_empty_array(&data_type);
+ let mut new_values = vec![];
- for (row_index, (arr, n)) in
list_array.iter().zip(arr_n.iter()).enumerate() {
+ // n is the number of elements to replace in this row
+ for (row_index, (list_array_row, n)) in
+ list_array.iter().zip(arr_n.iter()).enumerate()
+ {
let last_offset: i32 = offsets
.last()
.copied()
.ok_or_else(|| internal_datafusion_err!("offsets should not be
empty"))?;
- match arr {
- Some(arr) => {
- let indices = UInt32Array::from(vec![row_index as u32]);
- let from_arr = arrow::compute::take(from_array, &indices,
None)?;
- let eq_array = match from_arr.data_type() {
- // arrow_ord::cmp_eq does not support ListArray, so we
need to compare it by loop
+ match list_array_row {
+ Some(list_array_row) => {
+ let indices = UInt32Array::from(vec![row_index as u32]);
+ let from_array_row = arrow::compute::take(from_array,
&indices, None)?;
+ // Compute all positions in list_array_row (that is itself an
+ // array) that are equal to `from_array_row`
+ let eq_array = match from_array_row.data_type() {
+ // arrow_ord::cmp::eq does not support ListArray, so we
need to compare it by loop
DataType::List(_) => {
- let from_a = as_list_array(&from_arr)?.value(0);
- let list_arr = as_list_array(&arr)?;
+ // compare each element of the from array
+ let from_array_row_inner =
+ as_list_array(&from_array_row)?.value(0);
+ let list_array_row_inner =
as_list_array(&list_array_row)?;
- let mut bool_values = vec![];
- for arr in list_arr.iter() {
- if let Some(a) = arr {
- bool_values.push(Some(a.eq(&from_a)));
- } else {
- return internal_err!(
- "Null value is not supported in
array_replace"
- );
- }
- }
- BooleanArray::from(bool_values)
+ list_array_row_inner
+ .iter()
+ // compare element by element the current row of
list_array
+ .map(|row| row.map(|row|
row.eq(&from_array_row_inner)))
+ .collect::<BooleanArray>()
}
_ => {
- let from_arr = Scalar::new(from_arr);
- arrow_ord::cmp::eq(&arr, &from_arr)?
+ let from_arr = Scalar::new(from_array_row);
+ // use not_distinct so NULL = NULL
+ arrow_ord::cmp::not_distinct(&list_array_row,
&from_arr)?
}
};
// Use MutableArrayData to build the replaced array
+ let original_data = list_array_row.to_data();
+ let to_data = to_array.to_data();
+ let capacity = Capacities::Array(original_data.len() +
to_data.len());
+
// First array is the original array, second array is the
element to replace with.
- let arrays = vec![arr, to_array.clone()];
- let arrays_data = arrays
- .iter()
- .map(|a| a.to_data())
- .collect::<Vec<ArrayData>>();
- let arrays_data =
arrays_data.iter().collect::<Vec<&ArrayData>>();
-
- let arrays = arrays
- .iter()
- .map(|arr| arr.as_ref())
- .collect::<Vec<&dyn Array>>();
- let capacity = Capacities::Array(arrays.iter().map(|a|
a.len()).sum());
-
- let mut mutable =
- MutableArrayData::with_capacities(arrays_data, false,
capacity);
+ let mut mutable = MutableArrayData::with_capacities(
+ vec![&original_data, &to_data],
+ false,
+ capacity,
+ );
+ let original_idx = 0;
+ let replace_idx = 1;
let mut counter = 0;
for (i, to_replace) in eq_array.iter().enumerate() {
- if let Some(to_replace) = to_replace {
- if to_replace {
- mutable.extend(1, row_index, row_index + 1);
- counter += 1;
- if counter == *n {
- // extend the rest of the array
- mutable.extend(0, i + 1, eq_array.len());
- break;
- }
- } else {
- mutable.extend(0, i, i + 1);
+ if let Some(true) = to_replace {
+ mutable.extend(replace_idx, row_index, row_index + 1);
+ counter += 1;
+ if counter == *n {
+ // copy original data for any matches past n
+ mutable.extend(original_idx, i + 1,
eq_array.len());
+ break;
}
} else {
- return internal_err!("eq_array should not contain
None");
+ // copy original data for false / null matches
+ mutable.extend(original_idx, i, i + 1);
}
}
let data = mutable.freeze();
let replaced_array = arrow_array::make_array(data);
- let v = arrow::compute::concat(&[&values, &replaced_array])?;
- values = v;
offsets.push(last_offset + replaced_array.len() as i32);
+ new_values.push(replaced_array);
}
None => {
+ // Null element results in a null row (no new offsets)
offsets.push(last_offset);
}
}
}
+ let values = if new_values.is_empty() {
+ new_empty_array(&data_type)
+ } else {
+ let new_values: Vec<_> = new_values.iter().map(|a|
a.as_ref()).collect();
+ arrow::compute::concat(&new_values)?
+ };
+
Ok(Arc::new(ListArray::try_new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::new(offsets.into()),
values,
- None,
+ list_array.nulls().cloned(),
)?))
}
pub fn array_replace(args: &[ArrayRef]) -> Result<ArrayRef> {
- general_replace(args, vec![1; args[0].len()])
+ // replace at most one occurrence in each row
+ let arr_n = vec![1; args[0].len()];
+ general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
}
pub fn array_replace_n(args: &[ArrayRef]) -> Result<ArrayRef> {
- let arr = as_int64_array(&args[3])?;
- let arr_n = arr.values().to_vec();
- general_replace(args, arr_n)
+ // replace up to the specified number of occurrences in each row
+ let arr_n = as_int64_array(&args[3])?.values().to_vec();
+ general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
}
pub fn array_replace_all(args: &[ArrayRef]) -> Result<ArrayRef> {
- general_replace(args, vec![i64::MAX; args[0].len()])
+ // replace all occurrences (up to `i64::MAX` per row)
+ let arr_n = vec![i64::MAX; args[0].len()];
+ general_replace(as_list_array(&args[0])?, &args[1], &args[2], arr_n)
}
macro_rules! to_string {
diff --git a/datafusion/sqllogictest/test_files/array.slt
b/datafusion/sqllogictest/test_files/array.slt
index 85218efb5e..c57369c167 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -1720,6 +1720,37 @@ select array_replace_all(make_array([1, 2, 3], [4, 5,
6], [4, 5, 6], [10, 11, 12
[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12],
[10, 11, 12], [28, 29, 30], [28, 29, 30], [28, 29, 30], [28, 29, 30], [22, 23,
24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21],
[25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12,
13], [11, 12, 13], [11, 12, 13], [22, 23, 24], [11, 12, 13], [25, 26, 27], [11,
12, 13], [22, 23, 24], [11, 12, 13], [11, 12, 13]]
[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12],
[10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23,
24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33],
[34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12,
13], [31, 32, 33], [34, 35, 36], [11, 12, 13], [31, 32, 33], [34, 35, 36], [11,
12, 13], [31, 32, 33], [34, 35, 36], [11, 12, 13]]
+# array_replace with null handling
+
+statement ok
+create table t as values
+ (make_array(3, 1, NULL, 3), 3, 4, 2),
+ (make_array(3, 1, NULL, 3), NULL, 5, 2),
+ (NULL, 3, 2, 1),
+ (make_array(3, 1, 3), 3, NULL, 1)
+;
+
+
+# ([3, 1, NULL, 3], 3, 4, 2) => [4, 1, NULL, 4] NULL not matched
+# ([3, 1, NULL, 3], NULL, 5, 2) => [3, 1, 5, 3] NULL is replaced with 5
+# (NULL, 3, 2, 1) => NULL
+# ([3, 1, 3], 3, NULL, 1) => [NULL, 1, 3]
+
+query ?III?
+select column1, column2, column3, column4, array_replace_n(column1, column2,
column3, column4) from t;
+----
+[3, 1, , 3] 3 4 2 [4, 1, , 4]
+[3, 1, , 3] NULL 5 2 [3, 1, 5, 3]
+NULL 3 2 1 NULL
+[3, 1, 3] 3 NULL 1 [, 1, 3]
+
+
+
+statement ok
+drop table t;
+
+
+
## array_to_string (aliases: `list_to_string`, `array_join`, `list_join`)
# array_to_string scalar function #1