jorgecarleitao commented on a change in pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#discussion_r513161172



##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,259 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = 
Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed 
fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
-}
-
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Ord,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
-}
-
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
-}
-
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
-            }
-
-            // convert to bits with canonical pattern for NaN
-            let a = if a.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                a.to_bits()
-            };
-            let b = if b.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                b.to_bits()
-            };
-
-            if a == b {
-                // Equal or both NaN
-                Ordering::Equal
-            } else if a < b {
-                // (-0.0, 0.0) or (!NaN, NaN)
-                Ordering::Less
-            } else {
-                // (0.0, -0.0) or (NaN, !NaN)
-                Ordering::Greater
-            }
-        }
-    };
-}
-
-float_ord_cmp!(cmp_f64, f64);
-float_ord_cmp!(cmp_f32, f32);
-
-#[repr(transparent)]
-struct Float64ArrayAsOrdArray<'a>(&'a Float64Array);
-#[repr(transparent)]
-struct Float32ArrayAsOrdArray<'a>(&'a Float32Array);
-
-impl OrdArray for Float64ArrayAsOrdArray<'_> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        let a: f64 = self.0.value(i);
-        let b: f64 = self.0.value(j);
-
-        cmp_f64(a, b)
-    }
-}
-
-impl OrdArray for Float32ArrayAsOrdArray<'_> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        let a: f32 = self.0.value(i);
-        let b: f32 = self.0.value(j);
-
-        cmp_f32(a, b)
-    }
-}
-
-fn float32_as_ord_array<'a>(array: &'a ArrayRef) -> Box<dyn OrdArray + 'a> {
-    let float_array: &Float32Array = as_primitive_array::<Float32Type>(array);
-    Box::new(Float32ArrayAsOrdArray(float_array))
-}
-
-fn float64_as_ord_array<'a>(array: &'a ArrayRef) -> Box<dyn OrdArray + 'a> {
-    let float_array: &Float64Array = as_primitive_array::<Float64Type>(array);
-    Box::new(Float64ArrayAsOrdArray(float_array))
-}
-
-struct StringDictionaryArrayAsOrdArray<'a, T: ArrowDictionaryKeyType> {
-    dict_array: &'a DictionaryArray<T>,
-    values: StringArray,
-    keys: PrimitiveArray<T>,
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Float,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl<T: ArrowDictionaryKeyType> OrdArray for 
StringDictionaryArrayAsOrdArray<'_, T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        let keys = &self.keys;
-        let dict = &self.values;
-
-        let key_a: T::Native = keys.value(i);
-        let key_b: T::Native = keys.value(j);
-
-        let str_a = dict.value(key_a.to_usize().unwrap());
-        let str_b = dict.value(key_b.to_usize().unwrap());
-
-        str_a.cmp(str_b)
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> 
DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-fn string_dict_as_ord_array<'a, T: ArrowDictionaryKeyType>(
-    array: &'a ArrayRef,
-) -> Box<dyn OrdArray + 'a>
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> 
DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T: ArrowDictionaryKeyType,
 {
-    let dict_array = as_dictionary_array::<T>(array);
-    let keys = dict_array.keys_array();
-
-    let values = &dict_array.values();
-    let values = StringArray::from(values.data());
-
-    Box::new(StringDictionaryArrayAsOrdArray {
-        dict_array,
-        values,
-        keys,
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(left.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
     })
 }
 
-/// Convert ArrayRef to OrdArray trait object
-pub fn as_ordarray<'a>(values: &'a ArrayRef) -> Result<Box<OrdArray + 'a>> {
-    match values.data_type() {
-        DataType::Boolean => Ok(Box::new(as_boolean_array(&values))),
-        DataType::Utf8 => Ok(Box::new(as_string_array(&values))),
-        DataType::Null => Ok(Box::new(as_null_array(&values))),
-        DataType::Int8 => 
Ok(Box::new(as_primitive_array::<Int8Type>(&values))),
-        DataType::Int16 => 
Ok(Box::new(as_primitive_array::<Int16Type>(&values))),
-        DataType::Int32 => 
Ok(Box::new(as_primitive_array::<Int32Type>(&values))),
-        DataType::Int64 => 
Ok(Box::new(as_primitive_array::<Int64Type>(&values))),
-        DataType::UInt8 => 
Ok(Box::new(as_primitive_array::<UInt8Type>(&values))),
-        DataType::UInt16 => 
Ok(Box::new(as_primitive_array::<UInt16Type>(&values))),
-        DataType::UInt32 => 
Ok(Box::new(as_primitive_array::<UInt32Type>(&values))),
-        DataType::UInt64 => 
Ok(Box::new(as_primitive_array::<UInt64Type>(&values))),
-        DataType::Date32(_) => 
Ok(Box::new(as_primitive_array::<Date32Type>(&values))),
-        DataType::Date64(_) => 
Ok(Box::new(as_primitive_array::<Date64Type>(&values))),
-        DataType::Time32(Second) => {
-            Ok(Box::new(as_primitive_array::<Time32SecondType>(&values)))
+/// returns a comparison function that compares two values at two different 
positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the 
array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> 
Result<DynComparator<'a>> {

Review comment:
       Note that we were already using dynamic dispatch with `OrdArray`: in 
the lexical sort, we built a vector of arrays of unknown types and then called 
`cmp_value` on each of them. Because the vector contains heterogeneous array 
types, those calls are dynamically dispatched.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to