This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new f0200dbec1 Add AnyDictionary Abstraction and Take ArrayRef in DictionaryArray::with_values (#4707)
f0200dbec1 is described below
commit f0200dbec164c9593d80405b928a9684b598bf77
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Aug 17 10:54:24 2023 +0100
Add AnyDictionary Abstraction and Take ArrayRef in DictionaryArray::with_values (#4707)
* Add AnyDictionary Abstraction
* Review feedback
* Move to AsArray
---
arrow-arith/src/arity.rs | 8 ++-
arrow-arith/src/temporal.rs | 2 +-
arrow-array/src/array/dictionary_array.rs | 116 ++++++++++++++++++++++++++----
arrow-array/src/cast.rs | 20 ++++++
arrow-row/src/lib.rs | 2 +-
5 files changed, 129 insertions(+), 19 deletions(-)
diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs
index fdfb26f7f7..f3118d1045 100644
--- a/arrow-arith/src/arity.rs
+++ b/arrow-arith/src/arity.rs
@@ -82,7 +82,7 @@ where
{
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = unary::<T, F, T>(dict_values, op);
- Ok(Arc::new(array.with_values(&values)))
+ Ok(Arc::new(array.with_values(Arc::new(values))))
}
/// A helper function that applies a fallible unary function to a dictionary
array with primitive value type.
@@ -105,10 +105,11 @@ where
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = try_unary::<T, F, T>(dict_values, op)?;
- Ok(Arc::new(array.with_values(&values)))
+ Ok(Arc::new(array.with_values(Arc::new(values))))
}
/// Applies an infallible unary function to an array with primitive values.
+#[deprecated(note = "Use arrow_array::AnyDictionaryArray")]
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef,
ArrowError>
where
T: ArrowPrimitiveType,
@@ -134,6 +135,7 @@ where
}
/// Applies a fallible unary function to an array with primitive values.
+#[deprecated(note = "Use arrow_array::AnyDictionaryArray")]
pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef,
ArrowError>
where
T: ArrowPrimitiveType,
@@ -436,6 +438,7 @@ mod tests {
use arrow_array::types::*;
#[test]
+ #[allow(deprecated)]
fn test_unary_f64_slice() {
let input =
Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None,
Some(7.2)]);
@@ -455,6 +458,7 @@ mod tests {
}
#[test]
+ #[allow(deprecated)]
fn test_unary_dict_and_unary_dyn() {
let mut builder = PrimitiveDictionaryBuilder::<Int8Type,
Int32Type>::new();
builder.append(5).unwrap();
diff --git a/arrow-arith/src/temporal.rs b/arrow-arith/src/temporal.rs
index ef551ceedd..7855b6fc6e 100644
--- a/arrow-arith/src/temporal.rs
+++ b/arrow-arith/src/temporal.rs
@@ -462,7 +462,7 @@ where
downcast_dictionary_array!(
array => {
let values = time_fraction_dyn(array.values(), name, op)?;
- Ok(Arc::new(array.with_values(&values)))
+ Ok(Arc::new(array.with_values(values)))
}
dt => return_compute_error_with!(format!("{name} does not
support"), dt),
)
diff --git a/arrow-array/src/array/dictionary_array.rs
b/arrow-array/src/array/dictionary_array.rs
index 2d80c75f07..ed043754da 100644
--- a/arrow-array/src/array/dictionary_array.rs
+++ b/arrow-array/src/array/dictionary_array.rs
@@ -434,6 +434,7 @@ impl<K: ArrowDictionaryKeyType> DictionaryArray<K> {
/// Panics if `values` has a length less than the current values
///
/// ```
+ /// # use std::sync::Arc;
/// # use arrow_array::builder::PrimitiveDictionaryBuilder;
/// # use arrow_array::{Int8Array, Int64Array, ArrayAccessor};
/// # use arrow_array::types::{Int32Type, Int8Type};
@@ -451,7 +452,7 @@ impl<K: ArrowDictionaryKeyType> DictionaryArray<K> {
/// let values: Int64Array = typed_dictionary.values().unary(|x| x as i64);
///
/// // Create a Dict(Int32,
- /// let new = dictionary.with_values(&values);
+ /// let new = dictionary.with_values(Arc::new(values));
///
/// // Verify values are as expected
/// let new_typed = new.downcast_dict::<Int64Array>().unwrap();
@@ -460,21 +461,18 @@ impl<K: ArrowDictionaryKeyType> DictionaryArray<K> {
/// }
/// ```
///
- pub fn with_values(&self, values: &dyn Array) -> Self {
+ pub fn with_values(&self, values: ArrayRef) -> Self {
assert!(values.len() >= self.values.len());
-
- let builder = self
- .to_data()
- .into_builder()
- .data_type(DataType::Dictionary(
- Box::new(K::DATA_TYPE),
- Box::new(values.data_type().clone()),
- ))
- .child_data(vec![values.to_data()]);
-
- // SAFETY:
- // Offsets were valid before and verified length is greater than or
equal
- Self::from(unsafe { builder.build_unchecked() })
+ let data_type = DataType::Dictionary(
+ Box::new(K::DATA_TYPE),
+ Box::new(values.data_type().clone()),
+ );
+ Self {
+ data_type,
+ keys: self.keys.clone(), // still in bounds: keys were valid and values.len() did not shrink (asserted above)
+ values,
+ is_ordered: false, // always reset here, regardless of self.is_ordered
+ }
}
/// Returns `PrimitiveDictionaryBuilder` of this dictionary array for
mutating
@@ -930,6 +928,94 @@ where
}
}
+/// A [`DictionaryArray`] with the key type erased
+///
+/// This can be used to efficiently implement kernels for all possible dictionary
+/// keys without needing to create specialized implementations for each key type
+///
+/// For example
+///
+/// ```
+/// # use arrow_array::*;
+/// # use arrow_array::cast::AsArray;
+/// # use arrow_array::builder::PrimitiveDictionaryBuilder;
+/// # use arrow_array::types::*;
+/// # use arrow_schema::ArrowError;
+/// # use std::sync::Arc;
+///
+/// fn to_string(a: &dyn Array) -> Result<ArrayRef, ArrowError> {
+///     if let Some(d) = a.as_any_dictionary_opt() {
+///         // Recursively handle dictionary input
+///         let r = to_string(d.values().as_ref())?;
+///         return Ok(d.with_values(r));
+///     }
+///     downcast_primitive_array! {
+///         a => Ok(Arc::new(a.iter().map(|x| x.map(|x| x.to_string())).collect::<StringArray>())),
+///         d => Err(ArrowError::InvalidArgumentError(format!("{d:?} not supported")))
+///     }
+/// }
+///
+/// let result = to_string(&Int32Array::from(vec![1, 2, 3])).unwrap();
+/// let actual = result.as_string::<i32>().iter().map(Option::unwrap).collect::<Vec<_>>();
+/// assert_eq!(actual, &["1", "2", "3"]);
+///
+/// let mut dict = PrimitiveDictionaryBuilder::<Int32Type, UInt16Type>::new();
+/// dict.extend([Some(1), Some(1), Some(2), Some(3), Some(2)]);
+/// let dict = dict.finish();
+///
+/// let r = to_string(&dict).unwrap();
+/// let r = r.as_dictionary::<Int32Type>().downcast_dict::<StringArray>().unwrap();
+/// assert_eq!(r.keys(), dict.keys()); // Keys are the same
+///
+/// let actual = r.into_iter().map(Option::unwrap).collect::<Vec<_>>();
+/// assert_eq!(actual, &["1", "1", "2", "3", "2"]);
+/// ```
+///
+/// See [`AsArray::as_any_dictionary_opt`] and [`AsArray::as_any_dictionary`]
+pub trait AnyDictionaryArray: Array {
+    /// Returns the primitive keys of this dictionary as an [`Array`]
+    fn keys(&self) -> &dyn Array;
+
+    /// Returns the values of this dictionary
+    fn values(&self) -> &ArrayRef;
+
+    /// Returns the keys of this dictionary as usize
+    ///
+    /// The keys in null slots are arbitrary, but every returned index is
+    /// guaranteed to be in the range `0..self.values().len()`
+    ///
+    /// # Panics
+    ///
+    /// Panics if `self.values().len() == 0`
+    fn normalized_keys(&self) -> Vec<usize>;
+
+    /// Create a new [`DictionaryArray`] replacing `values` with the new values
+    ///
+    /// See [`DictionaryArray::with_values`]
+    fn with_values(&self, values: ArrayRef) -> ArrayRef;
+}
+
+impl<K: ArrowDictionaryKeyType> AnyDictionaryArray for DictionaryArray<K> {
+    fn keys(&self) -> &dyn Array {
+        &self.keys
+    }
+
+    fn values(&self) -> &ArrayRef {
+        self.values()
+    }
+
+    fn normalized_keys(&self) -> Vec<usize> {
+        let v_len = self.values().len();
+        assert_ne!(v_len, 0);
+        let iter = self.keys().values().iter();
+        iter.map(|x| x.as_usize().min(v_len - 1)).collect() // clamp into 0..v_len; keys under nulls are arbitrary
+    }
+
+    fn with_values(&self, values: ArrayRef) -> ArrayRef {
+        Arc::new(self.with_values(values))
+    }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs
index 66b40d5b8e..b6cda44e89 100644
--- a/arrow-array/src/cast.rs
+++ b/arrow-array/src/cast.rs
@@ -833,6 +833,14 @@ pub trait AsArray: private::Sealed {
fn as_dictionary<K: ArrowDictionaryKeyType>(&self) -> &DictionaryArray<K> {
self.as_dictionary_opt().expect("dictionary array")
}
+
+    /// Downcasts this to an [`AnyDictionaryArray`] returning `None` if not possible
+    fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray>;
+
+    /// Downcasts this to an [`AnyDictionaryArray`] panicking if not possible
+    fn as_any_dictionary(&self) -> &dyn AnyDictionaryArray {
+        self.as_any_dictionary_opt().expect("any dictionary array")
+    }
}
impl private::Sealed for dyn Array + '_ {}
@@ -874,6 +882,14 @@ impl AsArray for dyn Array + '_ {
) -> Option<&DictionaryArray<K>> {
self.as_any().downcast_ref()
}
+
+    fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray> {
+        let dict = self;
+        downcast_dictionary_array! {
+            dict => Some(dict),
+            _ => None
+        }
+    }
}
impl private::Sealed for ArrayRef {}
@@ -915,6 +931,10 @@ impl AsArray for ArrayRef {
) -> Option<&DictionaryArray<K>> {
self.as_ref().as_dictionary_opt()
}
+
+    fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray> {
+        (**self).as_any_dictionary_opt()
+    }
}
#[cfg(test)]
diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs
index 3cd082c511..18b5890d4a 100644
--- a/arrow-row/src/lib.rs
+++ b/arrow-row/src/lib.rs
@@ -1642,7 +1642,7 @@ mod tests {
// Construct dictionary with a timezone
let dict = a.finish();
let values = TimestampNanosecondArray::from(dict.values().to_data());
- let dict_with_tz = dict.with_values(&values.with_timezone("+02:00"));
+ let dict_with_tz =
dict.with_values(Arc::new(values.with_timezone("+02:00")));
let d = DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Timestamp(