This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 810291179f Take kernel dyn Array (#4705)
810291179f is described below

commit 810291179f65d63a5c49ed6b7881bc5788d85a9e
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Aug 17 10:48:33 2023 +0100

    Take kernel dyn Array (#4705)
---
 arrow-cast/src/cast.rs   |  16 +----
 arrow-select/src/take.rs | 153 +++++++++++++++++++++++++++++++++++------------
 2 files changed, 116 insertions(+), 53 deletions(-)

diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index a08a7a4fd4..23b7a4b5a0 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -49,7 +49,7 @@ use crate::parse::{
 use arrow_array::{
     builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *,
 };
-use arrow_buffer::{i256, ArrowNativeType, Buffer, OffsetBuffer, ScalarBuffer};
+use arrow_buffer::{i256, ArrowNativeType, Buffer, OffsetBuffer};
 use arrow_data::ArrayData;
 use arrow_schema::*;
 use arrow_select::take::take;
@@ -3027,19 +3027,7 @@ where
 {
     let dict_array = array.as_dictionary::<K>();
     let cast_dict_values = cast_with_options(dict_array.values(), to_type, cast_options)?;
-    let keys = dict_array.keys();
-    match K::DATA_TYPE {
-        DataType::Int32 => {
-            // Dictionary guarantees all non-null keys >= 0
-            let buffer = ScalarBuffer::new(keys.values().inner().clone(), 0, keys.len());
-            let indices = PrimitiveArray::new(buffer, keys.nulls().cloned());
-            take::<UInt32Type>(cast_dict_values.as_ref(), &indices, None)
-        }
-        _ => {
-            let indices = cast_with_options(keys, &DataType::UInt32, cast_options)?;
-            take::<UInt32Type>(cast_dict_values.as_ref(), indices.as_primitive(), None)
-        }
-    }
+    take(cast_dict_values.as_ref(), dict_array.keys(), None)
 }
 
 /// Attempts to encode an array into an `ArrayDictionary` with index
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index cee9cbaf84..70b80e5878 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -73,49 +73,65 @@ use num::{One, Zero};
 ///
 /// assert_eq!(*taken, StringArray::from(vec!["two", "one"]));
 /// ```
-pub fn take<IndexType: ArrowPrimitiveType>(
+pub fn take(
     values: &dyn Array,
-    indices: &PrimitiveArray<IndexType>,
+    indices: &dyn Array,
     options: Option<TakeOptions>,
 ) -> Result<ArrayRef, ArrowError> {
-    take_impl(values, indices, options)
+    let options = options.unwrap_or_default();
+    macro_rules! helper {
+        ($t:ty, $values:expr, $indices:expr, $options:expr) => {{
+            let indices = indices.as_primitive::<$t>();
+            if $options.check_bounds {
+                check_bounds($values.len(), indices)?;
+            }
+            let indices = indices.to_indices();
+            take_impl($values, &indices)
+        }};
+    }
+    downcast_integer! {
+        indices.data_type() => (helper, values, indices, options),
+        d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
+    }
+}
+
+/// Verifies that the non-null values of `indices` are all `< len`
+fn check_bounds<T: ArrowPrimitiveType>(
+    len: usize,
+    indices: &PrimitiveArray<T>,
+) -> Result<(), ArrowError> {
+    if indices.null_count() > 0 {
+        indices.iter().flatten().try_for_each(|index| {
+            let ix = index.to_usize().ok_or_else(|| {
+                ArrowError::ComputeError("Cast to usize failed".to_string())
+            })?;
+            if ix >= len {
+                return Err(ArrowError::ComputeError(
+                    format!("Array index out of bounds, cannot get item at index {ix} from {len} entries"))
+                );
+            }
+            Ok(())
+        })
+    } else {
+        indices.values().iter().try_for_each(|index| {
+            let ix = index.to_usize().ok_or_else(|| {
+                ArrowError::ComputeError("Cast to usize failed".to_string())
+            })?;
+            if ix >= len {
+                return Err(ArrowError::ComputeError(
+                    format!("Array index out of bounds, cannot get item at index {ix} from {len} entries"))
+                );
+            }
+            Ok(())
+        })
+    }
 }
 
+#[inline(never)]
 fn take_impl<IndexType: ArrowPrimitiveType>(
     values: &dyn Array,
     indices: &PrimitiveArray<IndexType>,
-    options: Option<TakeOptions>,
 ) -> Result<ArrayRef, ArrowError> {
-    let options = options.unwrap_or_default();
-    if options.check_bounds {
-        let len = values.len();
-        if indices.null_count() > 0 {
-            indices.iter().flatten().try_for_each(|index| {
-                let ix = index.to_usize().ok_or_else(|| {
-                    ArrowError::ComputeError("Cast to usize failed".to_string())
-                })?;
-                if ix >= len {
-                    return Err(ArrowError::ComputeError(
-                        format!("Array index out of bounds, cannot get item at index {ix} from {len} entries"))
-                    );
-                }
-                Ok(())
-            })?;
-        } else {
-            indices.values().iter().try_for_each(|index| {
-                let ix = index.to_usize().ok_or_else(|| {
-                    ArrowError::ComputeError("Cast to usize failed".to_string())
-                })?;
-                if ix >= len {
-                    return Err(ArrowError::ComputeError(
-                        format!("Array index out of bounds, cannot get item at index {ix} from {len} entries"))
-                    );
-                }
-                Ok(())
-            })?
-        }
-    }
-
     downcast_primitive_array! {
         values => Ok(Arc::new(take_primitive(values, indices)?)),
         DataType::Boolean => {
@@ -156,7 +172,7 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
             let arrays  = array
                 .columns()
                 .iter()
-                .map(|a| take_impl(a.as_ref(), indices, Some(options.clone())))
+                .map(|a| take_impl(a.as_ref(), indices))
                 .collect::<Result<Vec<ArrayRef>, _>>()?;
             let fields: Vec<(FieldRef, ArrayRef)> =
                 fields.iter().cloned().zip(arrays).collect();
@@ -423,7 +439,7 @@ where
     let (list_indices, offsets, null_buf) =
         take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
 
-    let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices, None)?;
+    let taken = take_impl::<OffsetType>(values.values().as_ref(), &list_indices)?;
     let value_offsets = Buffer::from_vec(offsets);
     // create a new list with taken data and computed null information
     let list_data = ArrayDataBuilder::new(values.data_type().clone())
@@ -449,7 +465,7 @@ fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
     length: <UInt32Type as ArrowPrimitiveType>::Native,
 ) -> Result<FixedSizeListArray, ArrowError> {
     let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
-    let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices, None)?;
+    let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
 
     // determine null count and null buffer, which are a function of `values` and `indices`
     let num_bytes = bit_util::ceil(indices.len(), 8);
@@ -676,6 +692,65 @@ where
     Ok(PrimitiveArray::<UInt32Type>::from(values))
 }
 
+/// To avoid generating take implementations for every index type, instead we
+/// only generate for UInt32 and UInt64 and coerce inputs to these types
+trait ToIndices {
+    type T: ArrowPrimitiveType;
+
+    fn to_indices(&self) -> PrimitiveArray<Self::T>;
+}
+
+macro_rules! to_indices_reinterpret {
+    ($t:ty, $o:ty) => {
+        impl ToIndices for PrimitiveArray<$t> {
+            type T = $o;
+
+            fn to_indices(&self) -> PrimitiveArray<$o> {
+                let cast =
+                    ScalarBuffer::new(self.values().inner().clone(), 0, self.len());
+                PrimitiveArray::new(cast, self.nulls().cloned())
+            }
+        }
+    };
+}
+
+macro_rules! to_indices_identity {
+    ($t:ty) => {
+        impl ToIndices for PrimitiveArray<$t> {
+            type T = $t;
+
+            fn to_indices(&self) -> PrimitiveArray<$t> {
+                self.clone()
+            }
+        }
+    };
+}
+
+macro_rules! to_indices_widening {
+    ($t:ty, $o:ty) => {
+        impl ToIndices for PrimitiveArray<$t> {
+            type T = UInt32Type;
+
+            fn to_indices(&self) -> PrimitiveArray<$o> {
+                let cast = self.values().iter().copied().map(|x| x as _).collect();
+                PrimitiveArray::new(cast, self.nulls().cloned())
+            }
+        }
+    };
+}
+
+to_indices_widening!(UInt8Type, UInt32Type);
+to_indices_widening!(Int8Type, UInt32Type);
+
+to_indices_widening!(UInt16Type, UInt32Type);
+to_indices_widening!(Int16Type, UInt32Type);
+
+to_indices_identity!(UInt32Type);
+to_indices_reinterpret!(Int32Type, UInt32Type);
+
+to_indices_identity!(UInt64Type);
+to_indices_reinterpret!(Int64Type, UInt64Type);
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -767,7 +842,7 @@ mod tests {
     {
         let output = PrimitiveArray::<T>::from(data);
         let expected = PrimitiveArray::<T>::from(expected_data);
-        let output = take_impl(&output, index, options).unwrap();
+        let output = take(&output, index, options).unwrap();
         let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
         assert_eq!(output, &expected)
     }
@@ -1078,7 +1153,7 @@ mod tests {
             1_639_715_368_000_000_000,
         ])
         .with_timezone("UTC".to_string());
-        let result = take_impl(&input, &index, None).unwrap();
+        let result = take(&input, &index, None).unwrap();
         match result.data_type() {
             DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
                 assert_eq!(tz.clone(), Some("UTC".into()))

Reply via email to