This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new df4906d76 Support building comparator for dictionaries of primitive integer values (#2673)
df4906d76 is described below

commit df4906d76992e26b7b196c1680755ca360272650
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Sep 7 17:10:38 2022 -0700

    Support building comparator for dictionaries of primitive integer values (#2673)
    
    * Support comparing dictionary of primitive value.
    
    * Change to generic function
    
    * Trigger Build
---
 arrow/src/array/ord.rs | 129 ++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 107 insertions(+), 22 deletions(-)

diff --git a/arrow/src/array/ord.rs b/arrow/src/array/ord.rs
index dd6539589..998c06e50 100644
--- a/arrow/src/array/ord.rs
+++ b/arrow/src/array/ord.rs
@@ -80,6 +80,31 @@ where
     Box::new(move |i, j| left.value(i).cmp(right.value(j)))
 }
 
+fn compare_dict_primitive<K, V>(left: &dyn Array, right: &dyn Array) -> DynComparator
+where
+    K: ArrowDictionaryKeyType,
+    V: ArrowPrimitiveType,
+    V::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+
+    let left_keys: PrimitiveArray<K> = PrimitiveArray::from(left.keys().data().clone());
+    let right_keys: PrimitiveArray<K> = PrimitiveArray::from(right.keys().data().clone());
+    let left_values: PrimitiveArray<V> =
+        PrimitiveArray::from(left.values().data().clone());
+    let right_values: PrimitiveArray<V> =
+        PrimitiveArray::from(right.values().data().clone());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
+}
+
 fn compare_dict_string<T>(left: &dyn Array, right: &dyn Array) -> DynComparator
 where
     T: ArrowDictionaryKeyType,
@@ -101,6 +126,35 @@ where
     })
 }
 
+fn cmp_dict_primitive<VT>(
+    key_type: &DataType,
+    left: &dyn Array,
+    right: &dyn Array,
+) -> Result<DynComparator>
+where
+    VT: ArrowPrimitiveType,
+    VT::Native: Ord,
+{
+    use DataType::*;
+
+    Ok(match key_type {
+        UInt8 => compare_dict_primitive::<UInt8Type, VT>(left, right),
+        UInt16 => compare_dict_primitive::<UInt16Type, VT>(left, right),
+        UInt32 => compare_dict_primitive::<UInt32Type, VT>(left, right),
+        UInt64 => compare_dict_primitive::<UInt64Type, VT>(left, right),
+        Int8 => compare_dict_primitive::<Int8Type, VT>(left, right),
+        Int16 => compare_dict_primitive::<Int16Type, VT>(left, right),
+        Int32 => compare_dict_primitive::<Int32Type, VT>(left, right),
+        Int64 => compare_dict_primitive::<Int64Type, VT>(left, right),
+        t => {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "Dictionaries do not support keys of type {:?}",
+                t
+            )));
+        }
+    })
+}
+
 /// returns a comparison function that compares two values at two different positions
 /// between the two arrays.
 /// The arrays' types must be equal.
@@ -195,32 +249,43 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparato
             Dictionary(key_type_lhs, value_type_lhs),
             Dictionary(key_type_rhs, value_type_rhs),
         ) => {
-            if value_type_lhs.as_ref() != &DataType::Utf8
-                || value_type_rhs.as_ref() != &DataType::Utf8
-            {
+            if key_type_lhs != key_type_rhs || value_type_lhs != value_type_rhs {
                 return Err(ArrowError::InvalidArgumentError(
-                    "Arrow still does not support comparisons of non-string dictionary arrays"
-                        .to_string(),
+                    "Can't compare arrays of different types".to_string(),
                 ));
             }
-            match (key_type_lhs.as_ref(), key_type_rhs.as_ref()) {
-                (a, b) if a != b => {
-                    return Err(ArrowError::InvalidArgumentError(
-                        "Can't compare arrays of different types".to_string(),
-                    ));
-                }
-                (UInt8, UInt8) => compare_dict_string::<UInt8Type>(left, right),
-                (UInt16, UInt16) => compare_dict_string::<UInt16Type>(left, right),
-                (UInt32, UInt32) => compare_dict_string::<UInt32Type>(left, right),
-                (UInt64, UInt64) => compare_dict_string::<UInt64Type>(left, right),
-                (Int8, Int8) => compare_dict_string::<Int8Type>(left, right),
-                (Int16, Int16) => compare_dict_string::<Int16Type>(left, right),
-                (Int32, Int32) => compare_dict_string::<Int32Type>(left, right),
-                (Int64, Int64) => compare_dict_string::<Int64Type>(left, right),
-                (lhs, _) => {
+
+            let key_type_lhs = key_type_lhs.as_ref();
+
+            match value_type_lhs.as_ref() {
+                Int8 => cmp_dict_primitive::<Int8Type>(key_type_lhs, left, right)?,
+                Int16 => cmp_dict_primitive::<Int16Type>(key_type_lhs, left, right)?,
+                Int32 => cmp_dict_primitive::<Int32Type>(key_type_lhs, left, right)?,
+                Int64 => cmp_dict_primitive::<Int64Type>(key_type_lhs, left, right)?,
+                UInt8 => cmp_dict_primitive::<UInt8Type>(key_type_lhs, left, right)?,
+                UInt16 => cmp_dict_primitive::<UInt16Type>(key_type_lhs, left, right)?,
+                UInt32 => cmp_dict_primitive::<UInt32Type>(key_type_lhs, left, right)?,
+                UInt64 => cmp_dict_primitive::<UInt64Type>(key_type_lhs, left, right)?,
+                Utf8 => match key_type_lhs {
+                    UInt8 => compare_dict_string::<UInt8Type>(left, right),
+                    UInt16 => compare_dict_string::<UInt16Type>(left, right),
+                    UInt32 => compare_dict_string::<UInt32Type>(left, right),
+                    UInt64 => compare_dict_string::<UInt64Type>(left, right),
+                    Int8 => compare_dict_string::<Int8Type>(left, right),
+                    Int16 => compare_dict_string::<Int16Type>(left, right),
+                    Int32 => compare_dict_string::<Int32Type>(left, right),
+                    Int64 => compare_dict_string::<Int64Type>(left, right),
+                    lhs => {
+                        return Err(ArrowError::InvalidArgumentError(format!(
+                            "Dictionaries do not support keys of type {:?}",
+                            lhs
+                        )));
+                    }
+                },
+                t => {
                     return Err(ArrowError::InvalidArgumentError(format!(
-                        "Dictionaries do not support keys of type {:?}",
-                        lhs
+                        "Dictionaries of value data type {:?} are not supported",
+                        t
                     )));
                 }
             }
@@ -339,4 +404,24 @@ pub mod tests {
         assert_eq!(Ordering::Greater, (cmp)(1, 3));
         Ok(())
     }
+
+    #[test]
+    fn test_primitive_dict() -> Result<()> {
+        let values = Int32Array::from(vec![1_i32, 0, 2, 5]);
+        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
+        let array1 = DictionaryArray::<Int8Type>::try_new(&keys, &values).unwrap();
+
+        let values = Int32Array::from(vec![2_i32, 3, 4, 5]);
+        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
+        let array2 = DictionaryArray::<Int8Type>::try_new(&keys, &values).unwrap();
+
+        let cmp = build_compare(&array1, &array2)?;
+
+        assert_eq!(Ordering::Less, (cmp)(0, 0));
+        assert_eq!(Ordering::Less, (cmp)(0, 3));
+        assert_eq!(Ordering::Equal, (cmp)(3, 3));
+        assert_eq!(Ordering::Greater, (cmp)(3, 1));
+        assert_eq!(Ordering::Greater, (cmp)(3, 2));
+        Ok(())
+    }
 }

Reply via email to