This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new a817192f5 Check predicate and values are the same length for run end array filter safety (#6675)
a817192f5 is described below
commit a817192f5a386b6ff35a893f99209572ccd18c96
Author: delamarch3 <[email protected]>
AuthorDate: Mon Nov 4 22:57:16 2024 +0000
Check predicate and values are the same length for run end array filter safety (#6675)
* ensure predicate and values have the same length before passing on to filter_run_end_array
* fix error wording
* have filter_run_end_array use filter array with run_ends max value size
* use skip and take to iterate over filter values in filter_run_end_array
* check array values in max_value_gt_predicate_len test
---
arrow-select/src/filter.rs | 33 ++++++++++++++++++++++++++++-----
1 file changed, 28 insertions(+), 5 deletions(-)
diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs
index d616b01e7..d96dad2f1 100644
--- a/arrow-select/src/filter.rs
+++ b/arrow-select/src/filter.rs
@@ -436,18 +436,21 @@ where
let mut start = 0i64;
let mut i = 0;
- let filter_values = pred.filter.values();
let mut count = R::default_value();
+ let filter_values = pred.filter.values();
for end in run_ends.inner().into_iter().map(|i| (*i).into()) {
let mut keep = false;
- // in filter_array the predicate array is checked to have the same len
as the run end array
- // this means the largest value in the run_ends is == to pred.len()
- // so we're always within bounds when calling value_unchecked
- for pred in (start..end).map(|i| unsafe {
filter_values.value_unchecked(i as usize) }) {
+
+ for pred in filter_values
+ .iter()
+ .skip(start as usize)
+ .take((end - start) as usize)
+ {
count += R::Native::from(pred);
keep |= pred
}
+
// this is to avoid branching
new_run_ends[i] = count;
i += keep as usize;
@@ -1280,6 +1283,26 @@ mod tests {
assert_eq!(0, actual.len());
}
+ #[test]
+ fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
+ let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
+ let values = Int64Array::from(vec![7, -2, 9, -8]);
+ let a = RunArray::try_new(&run_ends, &values).expect("Failed to create
RunArray");
+ let b = BooleanArray::from(vec![false, true, true]);
+ let c = filter(&a, &b).unwrap();
+ let actual: &RunArray<Int64Type> = as_run_array(&c);
+ assert_eq!(2, actual.len());
+
+ let expected = RunArray::try_new(
+ &Int64Array::from(vec![1, 2]),
+ &Int64Array::from(vec![7, -2]),
+ )
+ .expect("Failed to make expected RunArray test is broken");
+
+ assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
+ assert_eq!(actual.values(), expected.values())
+ }
+
#[test]
fn test_filter_dictionary_array() {
let values = [Some("hello"), None, Some("world"), Some("!")];