This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 4a3919b82 Add typed dictionary (#2136) (#2297)
4a3919b82 is described below

commit 4a3919b8226441ea5115699287839373c4c4c0c9
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Aug 5 11:29:29 2022 +0100

    Add typed dictionary (#2136) (#2297)
    
    * Add typed dictionary (#2136)
    
    * Review feedback
---
 arrow/src/array/array_dictionary.rs | 125 ++++++++++++++++++++++++++++++++++--
 arrow/src/array/mod.rs              |   2 +-
 2 files changed, 120 insertions(+), 7 deletions(-)

diff --git a/arrow/src/array/array_dictionary.rs b/arrow/src/array/array_dictionary.rs
index 4f7d5f9c1..2afc7a69e 100644
--- a/arrow/src/array/array_dictionary.rs
+++ b/arrow/src/array/array_dictionary.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::array::{ArrayAccessor, ArrayIter};
 use std::any::Any;
 use std::fmt;
 use std::iter::IntoIterator;
@@ -234,6 +235,28 @@ impl<K: ArrowPrimitiveType> DictionaryArray<K> {
                 .expect("Dictionary index not usize")
         })
     }
+
+    /// Downcast this dictionary to a [`TypedDictionaryArray`]
+    ///
+    /// ```
+    /// use arrow::array::{Array, ArrayAccessor, DictionaryArray, StringArray};
+    /// use arrow::datatypes::Int32Type;
+    ///
+    /// let orig = [Some("a"), Some("b"), None];
+    /// let dictionary = DictionaryArray::<Int32Type>::from_iter(orig);
+    /// let typed = dictionary.downcast_dict::<StringArray>().unwrap();
+    /// assert_eq!(typed.value(0), "a");
+    /// assert_eq!(typed.value(1), "b");
+    /// assert!(typed.is_null(2));
+    /// ```
+    ///
+    pub fn downcast_dict<V: 'static>(&self) -> Option<TypedDictionaryArray<'_, K, V>> {
+        let values = self.values.as_any().downcast_ref()?;
+        Some(TypedDictionaryArray {
+            dictionary: self,
+            values,
+        })
+    }
 }
 
 /// Constructs a `DictionaryArray` from an array data reference.
@@ -302,9 +325,7 @@ impl<T: ArrowPrimitiveType> From<DictionaryArray<T>> for ArrayData {
 ///     format!("{:?}", array)
 /// );
 /// ```
-impl<'a, T: ArrowPrimitiveType + ArrowDictionaryKeyType> FromIterator<Option<&'a str>>
-    for DictionaryArray<T>
-{
+impl<'a, T: ArrowDictionaryKeyType> FromIterator<Option<&'a str>> for DictionaryArray<T> {
     fn from_iter<I: IntoIterator<Item = Option<&'a str>>>(iter: I) -> Self {
         let it = iter.into_iter();
         let (lower, _) = it.size_hint();
@@ -342,9 +363,7 @@ impl<'a, T: ArrowPrimitiveType + ArrowDictionaryKeyType> FromIterator<Option<&'a
 ///     format!("{:?}", array)
 /// );
 /// ```
-impl<'a, T: ArrowPrimitiveType + ArrowDictionaryKeyType> FromIterator<&'a str>
-    for DictionaryArray<T>
-{
+impl<'a, T: ArrowDictionaryKeyType> FromIterator<&'a str> for DictionaryArray<T> {
     fn from_iter<I: IntoIterator<Item = &'a str>>(iter: I) -> Self {
         let it = iter.into_iter();
         let (lower, _) = it.size_hint();
@@ -385,6 +404,100 @@ impl<T: ArrowPrimitiveType> fmt::Debug for DictionaryArray<T> {
     }
 }
 
+/// A strongly-typed wrapper around a [`DictionaryArray`] that implements [`ArrayAccessor`]
+/// allowing fast access to its elements
+///
+/// ```
+/// use arrow::array::{ArrayIter, DictionaryArray, StringArray};
+/// use arrow::datatypes::Int32Type;
+///
+/// let orig = ["a", "b", "a", "b"];
+/// let dictionary = DictionaryArray::<Int32Type>::from_iter(orig);
+///
+/// // `TypedDictionaryArray` allows you to access the values directly
+/// let typed = dictionary.downcast_dict::<StringArray>().unwrap();
+///
+/// for (maybe_val, orig) in typed.into_iter().zip(orig) {
+///     assert_eq!(maybe_val.unwrap(), orig)
+/// }
+/// ```
+#[derive(Copy, Clone)]
+pub struct TypedDictionaryArray<'a, K: ArrowPrimitiveType, V> {
+    /// The dictionary array
+    dictionary: &'a DictionaryArray<K>,
+    /// The values of the dictionary
+    values: &'a V,
+}
+
+impl<'a, K: ArrowPrimitiveType, V> fmt::Debug for TypedDictionaryArray<'a, K, V> {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        writeln!(f, "TypedDictionaryArray({:?})", self.dictionary)
+    }
+}
+
+impl<'a, K: ArrowPrimitiveType, V> TypedDictionaryArray<'a, K, V> {
+    /// Returns the keys of this [`TypedDictionaryArray`]
+    pub fn keys(&self) -> &'a PrimitiveArray<K> {
+        self.dictionary.keys()
+    }
+
+    /// Returns the values of this [`TypedDictionaryArray`]
+    pub fn values(&self) -> &'a V {
+        self.values
+    }
+}
+
+impl<'a, K: ArrowPrimitiveType, V: Sync> Array for TypedDictionaryArray<'a, K, V> {
+    fn as_any(&self) -> &dyn Any {
+        self.dictionary
+    }
+
+    fn data(&self) -> &ArrayData {
+        &self.dictionary.data
+    }
+
+    fn into_data(self) -> ArrayData {
+        self.dictionary.into_data()
+    }
+}
+
+impl<'a, K, V> IntoIterator for TypedDictionaryArray<'a, K, V>
+where
+    K: ArrowPrimitiveType,
+    V: Sync + Send,
+    &'a V: ArrayAccessor,
+{
+    type Item = Option<<Self as ArrayAccessor>::Item>;
+    type IntoIter = ArrayIter<Self>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        ArrayIter::new(self)
+    }
+}
+
+impl<'a, K, V> ArrayAccessor for TypedDictionaryArray<'a, K, V>
+where
+    K: ArrowPrimitiveType,
+    V: Sync + Send,
+    &'a V: ArrayAccessor,
+{
+    type Item = <&'a V as ArrayAccessor>::Item;
+
+    fn value(&self, index: usize) -> Self::Item {
+        assert!(self.dictionary.is_valid(index), "{}", index);
+        let value_idx = self.dictionary.keys.value(index).to_usize().unwrap();
+        // Dictionary indexes should be valid
+        unsafe { self.values.value_unchecked(value_idx) }
+    }
+
+    unsafe fn value_unchecked(&self, index: usize) -> Self::Item {
+        let val = self.dictionary.keys.value_unchecked(index);
+        let value_idx = val.to_usize().unwrap();
+        // Dictionary indexes should be valid
+        self.values.value_unchecked(value_idx)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs
index 3785af85a..4a7667741 100644
--- a/arrow/src/array/mod.rs
+++ b/arrow/src/array/mod.rs
@@ -208,7 +208,7 @@ pub use self::array_fixed_size_list::FixedSizeListArray;
 #[deprecated(note = "Please use `Decimal128Array` instead")]
 pub type DecimalArray = Decimal128Array;
 
-pub use self::array_dictionary::DictionaryArray;
+pub use self::array_dictionary::{DictionaryArray, TypedDictionaryArray};
 pub use self::array_list::LargeListArray;
 pub use self::array_list::ListArray;
 pub use self::array_map::MapArray;

Reply via email to