This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new ec273e76db Cleanup DynComparator (#2654) (#4687)
ec273e76db is described below

commit ec273e76db12106db0a886529d9018763c11dc9f
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Sun Aug 13 21:38:49 2023 +0100

    Cleanup DynComparator (#2654) (#4687)
---
 arrow-ord/src/ord.rs | 392 ++++++++++++++++++---------------------------------
 1 file changed, 134 insertions(+), 258 deletions(-)

diff --git a/arrow-ord/src/ord.rs b/arrow-ord/src/ord.rs
index a33ead8ab0..4d6e3bde91 100644
--- a/arrow-ord/src/ord.rs
+++ b/arrow-ord/src/ord.rs
@@ -21,114 +21,59 @@ use arrow_array::cast::AsArray;
 use arrow_array::types::*;
 use arrow_array::*;
 use arrow_buffer::ArrowNativeType;
-use arrow_schema::{ArrowError, DataType};
+use arrow_schema::ArrowError;
 use std::cmp::Ordering;
 
 /// Compare the values at two arbitrary indices in two arrays.
 pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
 
-fn compare_primitives<T: ArrowPrimitiveType>(
+fn compare_primitive<T: ArrowPrimitiveType>(
     left: &dyn Array,
     right: &dyn Array,
 ) -> DynComparator
 where
     T::Native: ArrowNativeTypeOp,
 {
-    let left: PrimitiveArray<T> = PrimitiveArray::from(left.to_data());
-    let right: PrimitiveArray<T> = PrimitiveArray::from(right.to_data());
+    let left = left.as_primitive::<T>().clone();
+    let right = right.as_primitive::<T>().clone();
     Box::new(move |i, j| left.value(i).compare(right.value(j)))
 }
 
 fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator {
-    let left: BooleanArray = BooleanArray::from(left.to_data());
-    let right: BooleanArray = BooleanArray::from(right.to_data());
+    let left: BooleanArray = left.as_boolean().clone();
+    let right: BooleanArray = right.as_boolean().clone();
 
     Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-fn compare_string(left: &dyn Array, right: &dyn Array) -> DynComparator {
-    let left: StringArray = StringArray::from(left.to_data());
-    let right: StringArray = StringArray::from(right.to_data());
+fn compare_bytes<T: ByteArrayType>(left: &dyn Array, right: &dyn Array) -> DynComparator {
+    let left = left.as_bytes::<T>().clone();
+    let right = right.as_bytes::<T>().clone();
 
-    Box::new(move |i, j| left.value(i).cmp(right.value(j)))
-}
-
-fn compare_dict_primitive<K, V>(left: &dyn Array, right: &dyn Array) -> DynComparator
-where
-    K: ArrowDictionaryKeyType,
-    V: ArrowPrimitiveType,
-    V::Native: ArrowNativeTypeOp,
-{
-    let left = left.as_dictionary::<K>();
-    let right = right.as_dictionary::<K>();
-
-    let left_keys: PrimitiveArray<K> = PrimitiveArray::from(left.keys().to_data());
-    let right_keys: PrimitiveArray<K> = PrimitiveArray::from(right.keys().to_data());
-    let left_values: PrimitiveArray<V> = left.values().to_data().into();
-    let right_values: PrimitiveArray<V> = right.values().to_data().into();
-
-    Box::new(move |i: usize, j: usize| {
-        let key_left = left_keys.value(i).as_usize();
-        let key_right = right_keys.value(j).as_usize();
-        let left = left_values.value(key_left);
-        let right = right_values.value(key_right);
-        left.compare(right)
-    })
-}
-
-fn compare_dict_string<T>(left: &dyn Array, right: &dyn Array) -> DynComparator
-where
-    T: ArrowDictionaryKeyType,
-{
-    let left = left.as_dictionary::<T>();
-    let right = right.as_dictionary::<T>();
-
-    let left_keys: PrimitiveArray<T> = PrimitiveArray::from(left.keys().to_data());
-    let right_keys: PrimitiveArray<T> = PrimitiveArray::from(right.keys().to_data());
-    let left_values = StringArray::from(left.values().to_data());
-    let right_values = StringArray::from(right.values().to_data());
-
-    Box::new(move |i: usize, j: usize| {
-        let key_left = left_keys.value(i).as_usize();
-        let key_right = right_keys.value(j).as_usize();
-        let left = left_values.value(key_left);
-        let right = right_values.value(key_right);
-        left.cmp(right)
+    Box::new(move |i, j| {
+        let l: &[u8] = left.value(i).as_ref();
+        let r: &[u8] = right.value(j).as_ref();
+        l.cmp(r)
     })
 }
 
-fn cmp_dict_primitive<VT>(
-    key_type: &DataType,
+fn compare_dict<K: ArrowDictionaryKeyType>(
     left: &dyn Array,
     right: &dyn Array,
-) -> Result<DynComparator, ArrowError>
-where
-    VT: ArrowPrimitiveType,
-    VT::Native: ArrowNativeTypeOp,
-{
-    use DataType::*;
-
-    Ok(match key_type {
-        UInt8 => compare_dict_primitive::<UInt8Type, VT>(left, right),
-        UInt16 => compare_dict_primitive::<UInt16Type, VT>(left, right),
-        UInt32 => compare_dict_primitive::<UInt32Type, VT>(left, right),
-        UInt64 => compare_dict_primitive::<UInt64Type, VT>(left, right),
-        Int8 => compare_dict_primitive::<Int8Type, VT>(left, right),
-        Int16 => compare_dict_primitive::<Int16Type, VT>(left, right),
-        Int32 => compare_dict_primitive::<Int32Type, VT>(left, right),
-        Int64 => compare_dict_primitive::<Int64Type, VT>(left, right),
-        t => {
-            return Err(ArrowError::InvalidArgumentError(format!(
-                "Dictionaries do not support keys of type {t:?}"
-            )));
-        }
-    })
-}
+) -> Result<DynComparator, ArrowError> {
+    let left = left.as_dictionary::<K>();
+    let right = right.as_dictionary::<K>();
+
+    let cmp = build_compare(left.values().as_ref(), right.values().as_ref())?;
+    let left_keys = left.keys().clone();
+    let right_keys = right.keys().clone();
 
-macro_rules! cmp_dict_primitive_helper {
-    ($t:ty, $key_type_lhs:expr, $left:expr, $right:expr) => {
-        cmp_dict_primitive::<$t>($key_type_lhs, $left, $right)?
-    };
+    // TODO: Handle value nulls (#2687)
+    Ok(Box::new(move |i, j| {
+        let l = left_keys.value(i).as_usize();
+        let r = right_keys.value(j).as_usize();
+        cmp(l, r)
+    }))
 }
 
 /// returns a comparison function that compares two values at two different positions
@@ -145,7 +90,7 @@ macro_rules! cmp_dict_primitive_helper {
 /// let cmp = build_compare(&array1, &array2).unwrap();
 ///
 /// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
-/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// assert_eq!(std::cmp::Ordering::Less, cmp(0, 1));
 /// ```
 // This is a factory of comparisons.
 // The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
@@ -153,134 +98,47 @@ pub fn build_compare(
     left: &dyn Array,
     right: &dyn Array,
 ) -> Result<DynComparator, ArrowError> {
-    use arrow_schema::{DataType::*, IntervalUnit::*, TimeUnit::*};
-    Ok(match (left.data_type(), right.data_type()) {
-        (a, b) if a != b => {
-            return Err(ArrowError::InvalidArgumentError(
-                "Can't compare arrays of different types".to_string(),
-            ));
-        }
-        (Boolean, Boolean) => compare_boolean(left, right),
-        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
-        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
-        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
-        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
-        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
-        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
-        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
-        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
-        (Float16, Float16) => compare_primitives::<Float16Type>(left, right),
-        (Float32, Float32) => compare_primitives::<Float32Type>(left, right),
-        (Float64, Float64) => compare_primitives::<Float64Type>(left, right),
-        (Decimal128(_, _), Decimal128(_, _)) => {
-            compare_primitives::<Decimal128Type>(left, right)
-        }
-        (Decimal256(_, _), Decimal256(_, _)) => {
-            compare_primitives::<Decimal256Type>(left, right)
-        }
-        (Date32, Date32) => compare_primitives::<Date32Type>(left, right),
-        (Date64, Date64) => compare_primitives::<Date64Type>(left, right),
-        (Time32(Second), Time32(Second)) => {
-            compare_primitives::<Time32SecondType>(left, right)
-        }
-        (Time32(Millisecond), Time32(Millisecond)) => {
-            compare_primitives::<Time32MillisecondType>(left, right)
-        }
-        (Time64(Microsecond), Time64(Microsecond)) => {
-            compare_primitives::<Time64MicrosecondType>(left, right)
-        }
-        (Time64(Nanosecond), Time64(Nanosecond)) => {
-            compare_primitives::<Time64NanosecondType>(left, right)
-        }
-        (Timestamp(Second, _), Timestamp(Second, _)) => {
-            compare_primitives::<TimestampSecondType>(left, right)
-        }
-        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
-            compare_primitives::<TimestampMillisecondType>(left, right)
-        }
-        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
-            compare_primitives::<TimestampMicrosecondType>(left, right)
-        }
-        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
-            compare_primitives::<TimestampNanosecondType>(left, right)
-        }
-        (Interval(YearMonth), Interval(YearMonth)) => {
-            compare_primitives::<IntervalYearMonthType>(left, right)
-        }
-        (Interval(DayTime), Interval(DayTime)) => {
-            compare_primitives::<IntervalDayTimeType>(left, right)
-        }
-        (Interval(MonthDayNano), Interval(MonthDayNano)) => {
-            compare_primitives::<IntervalMonthDayNanoType>(left, right)
-        }
-        (Duration(Second), Duration(Second)) => {
-            compare_primitives::<DurationSecondType>(left, right)
-        }
-        (Duration(Millisecond), Duration(Millisecond)) => {
-            compare_primitives::<DurationMillisecondType>(left, right)
-        }
-        (Duration(Microsecond), Duration(Microsecond)) => {
-            compare_primitives::<DurationMicrosecondType>(left, right)
-        }
-        (Duration(Nanosecond), Duration(Nanosecond)) => {
-            compare_primitives::<DurationNanosecondType>(left, right)
-        }
-        (Utf8, Utf8) => compare_string(left, right),
-        (LargeUtf8, LargeUtf8) => compare_string(left, right),
-        (
-            Dictionary(key_type_lhs, value_type_lhs),
-            Dictionary(key_type_rhs, value_type_rhs),
-        ) => {
-            if key_type_lhs != key_type_rhs || value_type_lhs != value_type_rhs {
-                return Err(ArrowError::InvalidArgumentError(
-                    "Can't compare arrays of different types".to_string(),
-                ));
-            }
-
-            let key_type_lhs = key_type_lhs.as_ref();
-            downcast_primitive! {
-                value_type_lhs.as_ref() => (cmp_dict_primitive_helper, key_type_lhs, left, right),
-                Utf8 => match key_type_lhs {
-                    UInt8 => compare_dict_string::<UInt8Type>(left, right),
-                    UInt16 => compare_dict_string::<UInt16Type>(left, right),
-                    UInt32 => compare_dict_string::<UInt32Type>(left, right),
-                    UInt64 => compare_dict_string::<UInt64Type>(left, right),
-                    Int8 => compare_dict_string::<Int8Type>(left, right),
-                    Int16 => compare_dict_string::<Int16Type>(left, right),
-                    Int32 => compare_dict_string::<Int32Type>(left, right),
-                    Int64 => compare_dict_string::<Int64Type>(left, right),
-                    lhs => {
-                        return Err(ArrowError::InvalidArgumentError(format!(
-                            "Dictionaries do not support keys of type {lhs:?}"
-                        )));
-                    }
-                },
-                t => {
-                    return Err(ArrowError::InvalidArgumentError(format!(
-                        "Dictionaries of value data type {t:?} are not supported"
-                    )));
-                }
-            }
-        }
+    use arrow_schema::DataType::*;
+    macro_rules! primitive_helper {
+        ($t:ty, $left:expr, $right:expr) => {
+            Ok(compare_primitive::<$t>($left, $right))
+        };
+    }
+    downcast_primitive! {
+        left.data_type(), right.data_type() => (primitive_helper, left, right),
+        (Boolean, Boolean) => Ok(compare_boolean(left, right)),
+        (Utf8, Utf8) => Ok(compare_bytes::<Utf8Type>(left, right)),
+        (LargeUtf8, LargeUtf8) => Ok(compare_bytes::<LargeUtf8Type>(left, right)),
+        (Binary, Binary) => Ok(compare_bytes::<BinaryType>(left, right)),
+        (LargeBinary, LargeBinary) => Ok(compare_bytes::<LargeBinaryType>(left, right)),
         (FixedSizeBinary(_), FixedSizeBinary(_)) => {
-            let left: FixedSizeBinaryArray = left.to_data().into();
-            let right: FixedSizeBinaryArray = right.to_data().into();
-
-            Box::new(move |i, j| left.value(i).cmp(right.value(j)))
-        }
-        (lhs, _) => {
-            return Err(ArrowError::InvalidArgumentError(format!(
-                "The data type type {lhs:?} has no natural order"
-            )));
-        }
-    })
+            let left = left.as_fixed_size_binary().clone();
+            let right = right.as_fixed_size_binary().clone();
+            Ok(Box::new(move |i, j| left.value(i).cmp(right.value(j))))
+        },
+        (Dictionary(l_key, _), Dictionary(r_key, _)) => {
+             macro_rules! dict_helper {
+                ($t:ty, $left:expr, $right:expr) => {
+                     compare_dict::<$t>($left, $right)
+                 };
+             }
+            downcast_integer! {
+                 l_key.as_ref(), r_key.as_ref() => (dict_helper, left, right),
+                 _ => unreachable!()
+             }
+        },
+        (lhs, rhs) => Err(ArrowError::InvalidArgumentError(match lhs == rhs {
+            true => format!("The data type type {lhs:?} has no natural order"),
+            false => "Can't compare arrays of different types".to_string(),
+        }))
+    }
 }
 
 #[cfg(test)]
 pub mod tests {
     use super::*;
     use arrow_array::{FixedSizeBinaryArray, Float64Array, Int32Array};
-    use arrow_buffer::i256;
+    use arrow_buffer::{i256, OffsetBuffer};
     use half::f16;
     use std::cmp::Ordering;
     use std::sync::Arc;
@@ -292,7 +150,7 @@ pub mod tests {
 
         let cmp = build_compare(&array, &array).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 1));
+        assert_eq!(Ordering::Less, cmp(0, 1));
     }
 
     #[test]
@@ -304,7 +162,7 @@ pub mod tests {
 
         let cmp = build_compare(&array1, &array2).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 0));
+        assert_eq!(Ordering::Less, cmp(0, 0));
     }
 
     #[test]
@@ -323,7 +181,7 @@ pub mod tests {
 
         let cmp = build_compare(&array1, &array2).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 0));
+        assert_eq!(Ordering::Less, cmp(0, 0));
     }
 
     #[test]
@@ -332,7 +190,7 @@ pub mod tests {
 
         let cmp = build_compare(&array, &array).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 1));
+        assert_eq!(Ordering::Less, cmp(0, 1));
     }
 
     #[test]
@@ -341,7 +199,7 @@ pub mod tests {
 
         let cmp = build_compare(&array, &array).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 1));
+        assert_eq!(Ordering::Less, cmp(0, 1));
     }
 
     #[test]
@@ -350,8 +208,8 @@ pub mod tests {
 
         let cmp = build_compare(&array, &array).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 1));
-        assert_eq!(Ordering::Equal, (cmp)(1, 1));
+        assert_eq!(Ordering::Less, cmp(0, 1));
+        assert_eq!(Ordering::Equal, cmp(1, 1));
     }
 
     #[test]
@@ -360,8 +218,8 @@ pub mod tests {
 
         let cmp = build_compare(&array, &array).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 1));
-        assert_eq!(Ordering::Greater, (cmp)(1, 0));
+        assert_eq!(Ordering::Less, cmp(0, 1));
+        assert_eq!(Ordering::Greater, cmp(1, 0));
     }
 
     #[test]
@@ -373,8 +231,8 @@ pub mod tests {
             .unwrap();
 
         let cmp = build_compare(&array, &array).unwrap();
-        assert_eq!(Ordering::Less, (cmp)(1, 0));
-        assert_eq!(Ordering::Greater, (cmp)(0, 2));
+        assert_eq!(Ordering::Less, cmp(1, 0));
+        assert_eq!(Ordering::Greater, cmp(0, 2));
     }
 
     #[test]
@@ -390,8 +248,8 @@ pub mod tests {
         .unwrap();
 
         let cmp = build_compare(&array, &array).unwrap();
-        assert_eq!(Ordering::Less, (cmp)(1, 0));
-        assert_eq!(Ordering::Greater, (cmp)(0, 2));
+        assert_eq!(Ordering::Less, cmp(1, 0));
+        assert_eq!(Ordering::Greater, cmp(0, 2));
     }
 
     #[test]
@@ -401,9 +259,9 @@ pub mod tests {
 
         let cmp = build_compare(&array, &array).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 1));
-        assert_eq!(Ordering::Equal, (cmp)(3, 4));
-        assert_eq!(Ordering::Greater, (cmp)(2, 3));
+        assert_eq!(Ordering::Less, cmp(0, 1));
+        assert_eq!(Ordering::Equal, cmp(3, 4));
+        assert_eq!(Ordering::Greater, cmp(2, 3));
     }
 
     #[test]
@@ -415,9 +273,9 @@ pub mod tests {
 
         let cmp = build_compare(&a1, &a2).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 0));
-        assert_eq!(Ordering::Equal, (cmp)(0, 3));
-        assert_eq!(Ordering::Greater, (cmp)(1, 3));
+        assert_eq!(Ordering::Less, cmp(0, 0));
+        assert_eq!(Ordering::Equal, cmp(0, 3));
+        assert_eq!(Ordering::Greater, cmp(1, 3));
     }
 
     #[test]
@@ -432,11 +290,11 @@ pub mod tests {
 
         let cmp = build_compare(&array1, &array2).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 0));
-        assert_eq!(Ordering::Less, (cmp)(0, 3));
-        assert_eq!(Ordering::Equal, (cmp)(3, 3));
-        assert_eq!(Ordering::Greater, (cmp)(3, 1));
-        assert_eq!(Ordering::Greater, (cmp)(3, 2));
+        assert_eq!(Ordering::Less, cmp(0, 0));
+        assert_eq!(Ordering::Less, cmp(0, 3));
+        assert_eq!(Ordering::Equal, cmp(3, 3));
+        assert_eq!(Ordering::Greater, cmp(3, 1));
+        assert_eq!(Ordering::Greater, cmp(3, 2));
     }
 
     #[test]
@@ -451,11 +309,11 @@ pub mod tests {
 
         let cmp = build_compare(&array1, &array2).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 0));
-        assert_eq!(Ordering::Less, (cmp)(0, 3));
-        assert_eq!(Ordering::Equal, (cmp)(3, 3));
-        assert_eq!(Ordering::Greater, (cmp)(3, 1));
-        assert_eq!(Ordering::Greater, (cmp)(3, 2));
+        assert_eq!(Ordering::Less, cmp(0, 0));
+        assert_eq!(Ordering::Less, cmp(0, 3));
+        assert_eq!(Ordering::Equal, cmp(3, 3));
+        assert_eq!(Ordering::Greater, cmp(3, 1));
+        assert_eq!(Ordering::Greater, cmp(3, 2));
     }
 
     #[test]
@@ -470,11 +328,11 @@ pub mod tests {
 
         let cmp = build_compare(&array1, &array2).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 0));
-        assert_eq!(Ordering::Less, (cmp)(0, 3));
-        assert_eq!(Ordering::Equal, (cmp)(3, 3));
-        assert_eq!(Ordering::Greater, (cmp)(3, 1));
-        assert_eq!(Ordering::Greater, (cmp)(3, 2));
+        assert_eq!(Ordering::Less, cmp(0, 0));
+        assert_eq!(Ordering::Less, cmp(0, 3));
+        assert_eq!(Ordering::Equal, cmp(3, 3));
+        assert_eq!(Ordering::Greater, cmp(3, 1));
+        assert_eq!(Ordering::Greater, cmp(3, 2));
     }
 
     #[test]
@@ -489,11 +347,11 @@ pub mod tests {
 
         let cmp = build_compare(&array1, &array2).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 0));
-        assert_eq!(Ordering::Less, (cmp)(0, 3));
-        assert_eq!(Ordering::Equal, (cmp)(3, 3));
-        assert_eq!(Ordering::Greater, (cmp)(3, 1));
-        assert_eq!(Ordering::Greater, (cmp)(3, 2));
+        assert_eq!(Ordering::Less, cmp(0, 0));
+        assert_eq!(Ordering::Less, cmp(0, 3));
+        assert_eq!(Ordering::Equal, cmp(3, 3));
+        assert_eq!(Ordering::Greater, cmp(3, 1));
+        assert_eq!(Ordering::Greater, cmp(3, 2));
     }
 
     #[test]
@@ -508,11 +366,11 @@ pub mod tests {
 
         let cmp = build_compare(&array1, &array2).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 0));
-        assert_eq!(Ordering::Less, (cmp)(0, 3));
-        assert_eq!(Ordering::Equal, (cmp)(3, 3));
-        assert_eq!(Ordering::Greater, (cmp)(3, 1));
-        assert_eq!(Ordering::Greater, (cmp)(3, 2));
+        assert_eq!(Ordering::Less, cmp(0, 0));
+        assert_eq!(Ordering::Less, cmp(0, 3));
+        assert_eq!(Ordering::Equal, cmp(3, 3));
+        assert_eq!(Ordering::Greater, cmp(3, 1));
+        assert_eq!(Ordering::Greater, cmp(3, 2));
     }
 
     #[test]
@@ -527,11 +385,11 @@ pub mod tests {
 
         let cmp = build_compare(&array1, &array2).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 0));
-        assert_eq!(Ordering::Less, (cmp)(0, 3));
-        assert_eq!(Ordering::Equal, (cmp)(3, 3));
-        assert_eq!(Ordering::Greater, (cmp)(3, 1));
-        assert_eq!(Ordering::Greater, (cmp)(3, 2));
+        assert_eq!(Ordering::Less, cmp(0, 0));
+        assert_eq!(Ordering::Less, cmp(0, 3));
+        assert_eq!(Ordering::Equal, cmp(3, 3));
+        assert_eq!(Ordering::Greater, cmp(3, 1));
+        assert_eq!(Ordering::Greater, cmp(3, 2));
     }
 
     #[test]
@@ -556,10 +414,28 @@ pub mod tests {
 
         let cmp = build_compare(&array1, &array2).unwrap();
 
-        assert_eq!(Ordering::Less, (cmp)(0, 0));
-        assert_eq!(Ordering::Less, (cmp)(0, 3));
-        assert_eq!(Ordering::Equal, (cmp)(3, 3));
-        assert_eq!(Ordering::Greater, (cmp)(3, 1));
-        assert_eq!(Ordering::Greater, (cmp)(3, 2));
+        assert_eq!(Ordering::Less, cmp(0, 0));
+        assert_eq!(Ordering::Less, cmp(0, 3));
+        assert_eq!(Ordering::Equal, cmp(3, 3));
+        assert_eq!(Ordering::Greater, cmp(3, 1));
+        assert_eq!(Ordering::Greater, cmp(3, 2));
+    }
+
+    fn test_bytes_impl<T: ByteArrayType>() {
+        let offsets = OffsetBuffer::from_lengths([3, 3, 1]);
+        let a = GenericByteArray::<T>::new(offsets, b"abcdefa".into(), None);
+        let cmp = build_compare(&a, &a).unwrap();
+
+        assert_eq!(Ordering::Less, cmp(0, 1));
+        assert_eq!(Ordering::Greater, cmp(0, 2));
+        assert_eq!(Ordering::Equal, cmp(1, 1));
+    }
+
+    #[test]
+    fn test_bytes() {
+        test_bytes_impl::<Utf8Type>();
+        test_bytes_impl::<LargeUtf8Type>();
+        test_bytes_impl::<BinaryType>();
+        test_bytes_impl::<LargeBinaryType>();
     }
 }

Reply via email to