This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new df4906d76 Support building comparator for dictionaries of primitive integer values (#2673)
df4906d76 is described below
commit df4906d76992e26b7b196c1680755ca360272650
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Sep 7 17:10:38 2022 -0700
Support building comparator for dictionaries of primitive integer values (#2673)
* Support comparing dictionary of primitive value.
* Change to generic function
* Trigger Build
---
arrow/src/array/ord.rs | 129 ++++++++++++++++++++++++++++++++++++++++---------
1 file changed, 107 insertions(+), 22 deletions(-)
diff --git a/arrow/src/array/ord.rs b/arrow/src/array/ord.rs
index dd6539589..998c06e50 100644
--- a/arrow/src/array/ord.rs
+++ b/arrow/src/array/ord.rs
@@ -80,6 +80,31 @@ where
Box::new(move |i, j| left.value(i).cmp(right.value(j)))
}
+fn compare_dict_primitive<K, V>(left: &dyn Array, right: &dyn Array) ->
DynComparator
+where
+ K: ArrowDictionaryKeyType,
+ V: ArrowPrimitiveType,
+ V::Native: Ord,
+{
+ let left = left.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+ let right = right.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+
+ let left_keys: PrimitiveArray<K> =
PrimitiveArray::from(left.keys().data().clone());
+ let right_keys: PrimitiveArray<K> =
PrimitiveArray::from(right.keys().data().clone());
+ let left_values: PrimitiveArray<V> =
+ PrimitiveArray::from(left.values().data().clone());
+ let right_values: PrimitiveArray<V> =
+ PrimitiveArray::from(right.values().data().clone());
+
+ Box::new(move |i: usize, j: usize| {
+ let key_left = left_keys.value(i).to_usize().unwrap();
+ let key_right = right_keys.value(j).to_usize().unwrap();
+ let left = left_values.value(key_left);
+ let right = right_values.value(key_right);
+ left.cmp(&right)
+ })
+}
+
fn compare_dict_string<T>(left: &dyn Array, right: &dyn Array) -> DynComparator
where
T: ArrowDictionaryKeyType,
@@ -101,6 +126,35 @@ where
})
}
+fn cmp_dict_primitive<VT>(
+ key_type: &DataType,
+ left: &dyn Array,
+ right: &dyn Array,
+) -> Result<DynComparator>
+where
+ VT: ArrowPrimitiveType,
+ VT::Native: Ord,
+{
+ use DataType::*;
+
+ Ok(match key_type {
+ UInt8 => compare_dict_primitive::<UInt8Type, VT>(left, right),
+ UInt16 => compare_dict_primitive::<UInt16Type, VT>(left, right),
+ UInt32 => compare_dict_primitive::<UInt32Type, VT>(left, right),
+ UInt64 => compare_dict_primitive::<UInt64Type, VT>(left, right),
+ Int8 => compare_dict_primitive::<Int8Type, VT>(left, right),
+ Int16 => compare_dict_primitive::<Int16Type, VT>(left, right),
+ Int32 => compare_dict_primitive::<Int32Type, VT>(left, right),
+ Int64 => compare_dict_primitive::<Int64Type, VT>(left, right),
+ t => {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Dictionaries do not support keys of type {:?}",
+ t
+ )));
+ }
+ })
+}
+
/// returns a comparison function that compares two values at two different
positions
/// between the two arrays.
/// The arrays' types must be equal.
@@ -195,32 +249,43 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array)
-> Result<DynComparato
Dictionary(key_type_lhs, value_type_lhs),
Dictionary(key_type_rhs, value_type_rhs),
) => {
- if value_type_lhs.as_ref() != &DataType::Utf8
- || value_type_rhs.as_ref() != &DataType::Utf8
- {
+ if key_type_lhs != key_type_rhs || value_type_lhs !=
value_type_rhs {
return Err(ArrowError::InvalidArgumentError(
- "Arrow still does not support comparisons of non-string
dictionary arrays"
- .to_string(),
+ "Can't compare arrays of different types".to_string(),
));
}
- match (key_type_lhs.as_ref(), key_type_rhs.as_ref()) {
- (a, b) if a != b => {
- return Err(ArrowError::InvalidArgumentError(
- "Can't compare arrays of different types".to_string(),
- ));
- }
- (UInt8, UInt8) => compare_dict_string::<UInt8Type>(left,
right),
- (UInt16, UInt16) => compare_dict_string::<UInt16Type>(left,
right),
- (UInt32, UInt32) => compare_dict_string::<UInt32Type>(left,
right),
- (UInt64, UInt64) => compare_dict_string::<UInt64Type>(left,
right),
- (Int8, Int8) => compare_dict_string::<Int8Type>(left, right),
- (Int16, Int16) => compare_dict_string::<Int16Type>(left,
right),
- (Int32, Int32) => compare_dict_string::<Int32Type>(left,
right),
- (Int64, Int64) => compare_dict_string::<Int64Type>(left,
right),
- (lhs, _) => {
+
+ let key_type_lhs = key_type_lhs.as_ref();
+
+ match value_type_lhs.as_ref() {
+ Int8 => cmp_dict_primitive::<Int8Type>(key_type_lhs, left,
right)?,
+ Int16 => cmp_dict_primitive::<Int16Type>(key_type_lhs, left,
right)?,
+ Int32 => cmp_dict_primitive::<Int32Type>(key_type_lhs, left,
right)?,
+ Int64 => cmp_dict_primitive::<Int64Type>(key_type_lhs, left,
right)?,
+ UInt8 => cmp_dict_primitive::<UInt8Type>(key_type_lhs, left,
right)?,
+ UInt16 => cmp_dict_primitive::<UInt16Type>(key_type_lhs, left,
right)?,
+ UInt32 => cmp_dict_primitive::<UInt32Type>(key_type_lhs, left,
right)?,
+ UInt64 => cmp_dict_primitive::<UInt64Type>(key_type_lhs, left,
right)?,
+ Utf8 => match key_type_lhs {
+ UInt8 => compare_dict_string::<UInt8Type>(left, right),
+ UInt16 => compare_dict_string::<UInt16Type>(left, right),
+ UInt32 => compare_dict_string::<UInt32Type>(left, right),
+ UInt64 => compare_dict_string::<UInt64Type>(left, right),
+ Int8 => compare_dict_string::<Int8Type>(left, right),
+ Int16 => compare_dict_string::<Int16Type>(left, right),
+ Int32 => compare_dict_string::<Int32Type>(left, right),
+ Int64 => compare_dict_string::<Int64Type>(left, right),
+ lhs => {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Dictionaries do not support keys of type {:?}",
+ lhs
+ )));
+ }
+ },
+ t => {
return Err(ArrowError::InvalidArgumentError(format!(
- "Dictionaries do not support keys of type {:?}",
- lhs
+ "Dictionaries of value data type {:?} are not
supported",
+ t
)));
}
}
@@ -339,4 +404,24 @@ pub mod tests {
assert_eq!(Ordering::Greater, (cmp)(1, 3));
Ok(())
}
+
+ #[test]
+ fn test_primitive_dict() -> Result<()> {
+ let values = Int32Array::from(vec![1_i32, 0, 2, 5]);
+ let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
+ let array1 = DictionaryArray::<Int8Type>::try_new(&keys,
&values).unwrap();
+
+ let values = Int32Array::from(vec![2_i32, 3, 4, 5]);
+ let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
+ let array2 = DictionaryArray::<Int8Type>::try_new(&keys,
&values).unwrap();
+
+ let cmp = build_compare(&array1, &array2)?;
+
+ assert_eq!(Ordering::Less, (cmp)(0, 0));
+ assert_eq!(Ordering::Less, (cmp)(0, 3));
+ assert_eq!(Ordering::Equal, (cmp)(3, 3));
+ assert_eq!(Ordering::Greater, (cmp)(3, 1));
+ assert_eq!(Ordering::Greater, (cmp)(3, 2));
+ Ok(())
+ }
}