This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new d011e6adc perf: `take_run` improvements (#3705)
d011e6adc is described below

commit d011e6adc6c7ff0ae72784e89cf736112016c9c5
Author: askoa <[email protected]>
AuthorDate: Mon Feb 13 10:50:51 2023 -0500

    perf: `take_run` improvements (#3705)
    
    * take_run improvements
    
    * doc fix
    
    * test case update per pr comment
    
    ---------
    
    Co-authored-by: ask <ask@local>
---
 arrow-select/src/take.rs | 124 +++++++++++++++++++++++++----------------------
 1 file changed, 66 insertions(+), 58 deletions(-)

diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index f8383bbe3..22991c4f2 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -19,15 +19,14 @@
 
 use std::sync::Arc;
 
+use arrow_array::builder::BufferBuilder;
+use arrow_array::types::*;
 use arrow_array::*;
-use arrow_array::{builder::PrimitiveRunBuilder, types::*};
 use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer};
 use arrow_data::{ArrayData, ArrayDataBuilder};
 use arrow_schema::{ArrowError, DataType, Field};
 
-use arrow_array::cast::{
-    as_generic_binary_array, as_largestring_array, as_primitive_array, as_string_array,
-};
+use arrow_array::cast::{as_generic_binary_array, as_largestring_array, as_string_array};
 use num::{ToPrimitive, Zero};
 
 /// Take elements by index from [Array], creating a new [Array] from those indexes.
@@ -816,22 +815,14 @@ where
     Ok(DictionaryArray::<T>::from(data))
 }
 
-macro_rules! primitive_run_take {
-    ($t:ty, $o:ty, $indices:ident, $value:ident) => {
-        take_primitive_run_values::<$o, $t>(
-            $indices,
-            as_primitive_array::<$t>($value.values()),
-        )
-    };
-}
-
 /// `take` implementation for run arrays
 ///
 /// Finds physical indices for the given logical indices and builds output run array
-/// by taking values in the input run array at the physical indices.
-/// for e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `indices=[2,7]`
-/// would be converted to `physical_indices=[1,3]` which will be used to build
-/// output `RunArray{ run_ends=[2], values=[2] }`
+/// by taking values in the input run_array.values at the physical indices.
+/// The output run array will be run encoded on the physical indices and not on output values.
+/// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]`
+/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build
+/// output `RunArray{ run_ends=[2,4], values=[2,2] }`.
 fn take_run<T, I>(
     run_array: &RunArray<T>,
     logical_indices: &PrimitiveArray<I>,
@@ -842,43 +833,60 @@ where
     I: ArrowPrimitiveType,
     I::Native: ToPrimitive,
 {
-    match run_array.data_type() {
-        DataType::RunEndEncoded(_, fl) => {
-            let physical_indices =
-                run_array.get_physical_indices(logical_indices.values())?;
-
-            downcast_primitive! {
-                fl.data_type() => (primitive_run_take, T, physical_indices, run_array),
-                dt => Err(ArrowError::NotYetImplemented(format!("take_run is not implemented for {dt:?}")))
-            }
+    // get physical indices for the input logical indices
+    let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
+
+    // Run encode the physical indices into new_run_ends_builder
+    // Keep track of the physical indices to take in take_value_indices
+    // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
+    let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
+    let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
+    let mut new_physical_len = 1;
+    for ix in 1..physical_indices.len() {
+        if physical_indices[ix] != physical_indices[ix - 1] {
+            take_value_indices
+                .append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
+            new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
+            new_physical_len += 1;
         }
-        dt => Err(ArrowError::InvalidArgumentError(format!(
-            "Expected DataType::RunEndEncoded found {dt:?}"
-        ))),
     }
-}
+    take_value_indices.append(
+        I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap(),
+    );
+    new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
+    let new_run_ends = unsafe {
+        // Safety:
+        // The function builds a valid run_ends array and hence need not be validated.
+        ArrayDataBuilder::new(T::DATA_TYPE)
+            .len(new_physical_len)
+            .null_count(0)
+            .add_buffer(new_run_ends_builder.finish())
+            .build_unchecked()
+    };
 
-// Builds a `RunArray` by taking values from given array for the given indices.
-fn take_primitive_run_values<R, V>(
-    physical_indices: Vec<usize>,
-    values: &PrimitiveArray<V>,
-) -> Result<RunArray<R>, ArrowError>
-where
-    R: RunEndIndexType,
-    V: ArrowPrimitiveType,
-{
-    let mut builder = PrimitiveRunBuilder::<R, V>::new();
-    let values_len = values.len();
-    for ix in physical_indices {
-        if ix >= values_len {
-            return Err(ArrowError::InvalidArgumentError("The requested index {ix} is out of bounds for values array with length {values_len}".to_string()));
-        } else if values.is_null(ix) {
-            builder.append_null()
-        } else {
-            builder.append_value(values.value(ix))
-        }
-    }
-    Ok(builder.finish())
+    let take_value_indices: PrimitiveArray<I> = unsafe {
+        // Safety:
+        // The function builds a valid take_value_indices array and hence need not be validated.
+        ArrayDataBuilder::new(I::DATA_TYPE)
+            .len(new_physical_len)
+            .null_count(0)
+            .add_buffer(take_value_indices.finish())
+            .build_unchecked()
+            .into()
+    };
+
+    let new_values = take(run_array.values(), &take_value_indices, None)?;
+
+    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
+        .len(physical_indices.len())
+        .add_child_data(new_run_ends)
+        .add_child_data(new_values.into_data());
+    let array_data = unsafe {
+        // Safety:
+        //  This function builds a valid run array and hence can skip validation.
+        builder.build_unchecked()
+    };
+    Ok(array_data.into())
 }
 
 /// Takes/filters a list array's inner data using the offsets of the list array.
@@ -983,7 +991,7 @@ where
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow_array::builder::*;
+    use arrow_array::{builder::*, cast::as_primitive_array};
     use arrow_schema::TimeUnit;
 
     fn test_take_decimal_arrays(
@@ -2159,24 +2167,24 @@ mod tests {
 
     #[test]
     fn test_take_runs() {
-        let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2];
+        let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
 
         let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
         builder.extend(logical_array.into_iter().map(Some));
         let run_array = builder.finish();
 
         let take_indices: PrimitiveArray<Int32Type> =
-            vec![2, 7, 10].into_iter().collect();
+            vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
 
         let take_out = take_run(&run_array, &take_indices).unwrap();
 
-        assert_eq!(take_out.len(), 3);
+        assert_eq!(take_out.len(), 7);
 
-        assert_eq!(take_out.run_ends().len(), 1);
-        assert_eq!(take_out.run_ends().value(0), 3);
+        assert_eq!(take_out.run_ends().len(), 5);
+        assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
 
         let take_out_values = as_primitive_array::<Int32Type>(take_out.values());
-        assert_eq!(take_out_values.value(0), 2);
+        assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
     }
 
     #[test]

Reply via email to