jorgecarleitao commented on a change in pull request #8092:
URL: https://github.com/apache/arrow/pull/8092#discussion_r483900538



##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -68,65 +80,233 @@ impl OrdArray for NullArray {
     }
 }
 
+macro_rules! float_ord_cmp {
+    ($NAME: ident, $T: ty) => {
+        #[inline]
+        fn $NAME(a: $T, b: $T) -> Ordering {
+            if a < b {
+                return Ordering::Less;
+            }
+            if a > b {
+                return Ordering::Greater;
+            }
+
+            // convert to bits with canonical pattern for NaN
+            let a = if a.is_nan() {
+                <$T>::NAN.to_bits()
+            } else {
+                a.to_bits()
+            };
+            let b = if b.is_nan() {
+                <$T>::NAN.to_bits()
+            } else {
+                b.to_bits()
+            };
+
+            if a == b {
+                // Equal or both NaN
+                Ordering::Equal
+            } else if a < b {
+                // (-0.0, 0.0) or (!NaN, NaN)
+                Ordering::Less
+            } else {
+                // (0.0, -0.0) or (NaN, !NaN)
+                Ordering::Greater
+            }
+        }
+    };
+}
+
+float_ord_cmp!(cmp_f64, f64);
+float_ord_cmp!(cmp_f32, f32);
+
+#[repr(transparent)]
+struct Float64ArrayAsOrdArray<'a>(&'a Float64Array);
+#[repr(transparent)]
+struct Float32ArrayAsOrdArray<'a>(&'a Float32Array);
+
+impl OrdArray for Float64ArrayAsOrdArray<'_> {
+    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
+        let a: f64 = self.0.value(i);
+        let b: f64 = self.0.value(j);
+
+        cmp_f64(a, b)
+    }
+}
+
+impl OrdArray for Float32ArrayAsOrdArray<'_> {
+    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
+        let a: f32 = self.0.value(i);
+        let b: f32 = self.0.value(j);
+
+        cmp_f32(a, b)
+    }
+}
+
+fn float32_as_ord_array<'a>(array: &'a ArrayRef) -> Box<dyn OrdArray + 'a> {
+    let float_array: &Float32Array = as_primitive_array::<Float32Type>(array);
+    Box::new(Float32ArrayAsOrdArray(float_array))
+}
+
+fn float64_as_ord_array<'a>(array: &'a ArrayRef) -> Box<dyn OrdArray + 'a> {
+    let float_array: &Float64Array = as_primitive_array::<Float64Type>(array);
+    Box::new(Float64ArrayAsOrdArray(float_array))
+}
+
+struct StringDictionaryArrayAsOrdArray<'a, T: ArrowDictionaryKeyType> {
+    dict_array: &'a DictionaryArray<T>,
+    values: StringArray,
+    keys: PrimitiveArray<T>,
+}
+
+impl<T: ArrowDictionaryKeyType> OrdArray for 
StringDictionaryArrayAsOrdArray<'_, T> {
+    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
+        let keys = &self.keys;
+
+        let a: T::Native = keys.value(i);
+        let b: T::Native = keys.value(j);
+
+        let dict = &self.values;
+
+        let sa = dict.value(a.to_usize().unwrap());
+        let sb = dict.value(b.to_usize().unwrap());
+
+        sa.cmp(sb)
+    }
+}
+
+fn string_dict_as_ord_array<'a, T: ArrowDictionaryKeyType>(
+    array: &'a ArrayRef,
+) -> Box<dyn OrdArray + 'a>
+where
+    T::Native: std::cmp::Ord,
+{
+    let dict_array = as_dictionary_array::<T>(array);
+    let keys = dict_array.keys_array();
+
+    let values = &dict_array.values();
+    let values = StringArray::from(values.data());
+
+    Box::new(StringDictionaryArrayAsOrdArray {
+        dict_array,
+        values,
+        keys,
+    })
+}
+
 /// Convert ArrayRef to OrdArray trait object
-pub fn as_ordarray(values: &ArrayRef) -> Result<&OrdArray> {
+pub fn as_ordarray<'a>(values: &'a ArrayRef) -> Result<Box<OrdArray + 'a>> {
     match values.data_type() {
-        DataType::Boolean => Ok(as_boolean_array(&values)),
-        DataType::Utf8 => Ok(as_string_array(&values)),
-        DataType::Null => Ok(as_null_array(&values)),
-        DataType::Int8 => Ok(as_primitive_array::<Int8Type>(&values)),
-        DataType::Int16 => Ok(as_primitive_array::<Int16Type>(&values)),
-        DataType::Int32 => Ok(as_primitive_array::<Int32Type>(&values)),
-        DataType::Int64 => Ok(as_primitive_array::<Int64Type>(&values)),
-        DataType::UInt8 => Ok(as_primitive_array::<UInt8Type>(&values)),
-        DataType::UInt16 => Ok(as_primitive_array::<UInt16Type>(&values)),
-        DataType::UInt32 => Ok(as_primitive_array::<UInt32Type>(&values)),
-        DataType::UInt64 => Ok(as_primitive_array::<UInt64Type>(&values)),
-        DataType::Date32(_) => Ok(as_primitive_array::<Date32Type>(&values)),
-        DataType::Date64(_) => Ok(as_primitive_array::<Date64Type>(&values)),
-        DataType::Time32(Second) => 
Ok(as_primitive_array::<Time32SecondType>(&values)),
-        DataType::Time32(Millisecond) => {
-            Ok(as_primitive_array::<Time32MillisecondType>(&values))
-        }
-        DataType::Time64(Microsecond) => {
-            Ok(as_primitive_array::<Time64MicrosecondType>(&values))
-        }
-        DataType::Time64(Nanosecond) => {
-            Ok(as_primitive_array::<Time64NanosecondType>(&values))
+        //DataType::Boolean => Ok(Box::new(as_boolean_array(&values))),

Review comment:
       Any particular reason for commenting this one out?

##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -428,13 +596,63 @@ mod tests {
         assert!(output.equals(&expected))
     }
 
+    fn test_sort_string_dict_arrays<T: ArrowDictionaryKeyType>(

Review comment:
       nice :)

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -68,65 +80,233 @@ impl OrdArray for NullArray {
     }
 }
 
+macro_rules! float_ord_cmp {
+    ($NAME: ident, $T: ty) => {
+        #[inline]
+        fn $NAME(a: $T, b: $T) -> Ordering {
+            if a < b {
+                return Ordering::Less;
+            }
+            if a > b {
+                return Ordering::Greater;
+            }
+
+            // convert to bits with canonical pattern for NaN
+            let a = if a.is_nan() {
+                <$T>::NAN.to_bits()
+            } else {
+                a.to_bits()
+            };
+            let b = if b.is_nan() {
+                <$T>::NAN.to_bits()
+            } else {
+                b.to_bits()
+            };
+
+            if a == b {
+                // Equal or both NaN
+                Ordering::Equal
+            } else if a < b {
+                // (-0.0, 0.0) or (!NaN, NaN)
+                Ordering::Less
+            } else {
+                // (0.0, -0.0) or (NaN, !NaN)
+                Ordering::Greater
+            }
+        }
+    };
+}
+
+float_ord_cmp!(cmp_f64, f64);
+float_ord_cmp!(cmp_f32, f32);
+
+#[repr(transparent)]
+struct Float64ArrayAsOrdArray<'a>(&'a Float64Array);
+#[repr(transparent)]
+struct Float32ArrayAsOrdArray<'a>(&'a Float32Array);
+
+impl OrdArray for Float64ArrayAsOrdArray<'_> {
+    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
+        let a: f64 = self.0.value(i);
+        let b: f64 = self.0.value(j);
+
+        cmp_f64(a, b)
+    }
+}
+
+impl OrdArray for Float32ArrayAsOrdArray<'_> {
+    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
+        let a: f32 = self.0.value(i);
+        let b: f32 = self.0.value(j);
+
+        cmp_f32(a, b)
+    }
+}
+
+fn float32_as_ord_array<'a>(array: &'a ArrayRef) -> Box<dyn OrdArray + 'a> {
+    let float_array: &Float32Array = as_primitive_array::<Float32Type>(array);
+    Box::new(Float32ArrayAsOrdArray(float_array))
+}
+
+fn float64_as_ord_array<'a>(array: &'a ArrayRef) -> Box<dyn OrdArray + 'a> {
+    let float_array: &Float64Array = as_primitive_array::<Float64Type>(array);
+    Box::new(Float64ArrayAsOrdArray(float_array))
+}
+
+struct StringDictionaryArrayAsOrdArray<'a, T: ArrowDictionaryKeyType> {
+    dict_array: &'a DictionaryArray<T>,
+    values: StringArray,
+    keys: PrimitiveArray<T>,
+}
+
+impl<T: ArrowDictionaryKeyType> OrdArray for 
StringDictionaryArrayAsOrdArray<'_, T> {
+    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
+        let keys = &self.keys;
+
+        let a: T::Native = keys.value(i);
+        let b: T::Native = keys.value(j);
+
+        let dict = &self.values;
+
+        let sa = dict.value(a.to_usize().unwrap());

Review comment:
       What does `sa` stand for? `string_a`?

##########
File path: rust/datafusion/src/physical_plan/sort.rs
##########
@@ -223,4 +223,51 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_lex_sort_by_float() -> Result<()> {
+        let schema = test::aggr_test_schema();
+        let partitions = 4;
+        let path = test::create_partitioned_csv("aggregate_test_100.csv", 
partitions)?;
+        let csv =
+            CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), 
None, 1024)?;
+
+        let sort_exec = Arc::new(SortExec::try_new(
+            vec![
+                // c11 float32 column
+                PhysicalSortExpr {
+                    expr: col("c11"),
+                    options: SortOptions::default(),
+                },
+                // c12 float64 column
+                PhysicalSortExpr {
+                    expr: col("c12"),
+                    options: SortOptions::default(),
+                },
+            ],
+            Arc::new(MergeExec::new(Arc::new(csv), 2)),
+            2,
+        )?);
+
+        assert_eq!(DataType::Float32, 
*sort_exec.schema().field(10).data_type());
+        assert_eq!(DataType::Float64, 
*sort_exec.schema().field(11).data_type());
+
+        let result: Vec<RecordBatch> = test::execute(sort_exec)?;
+        assert_eq!(result.len(), 1);
+
+        let columns = result[0].columns();
+
+        assert_eq!(DataType::Float32, *columns[10].data_type());
+        assert_eq!(DataType::Float64, *columns[11].data_type());
+
+        let c11 = as_primitive_array::<Float32Type>(&columns[10]);
+        assert_eq!(c11.value(0), 0.028003037_f32);

Review comment:
       It is difficult to tell why this particular value is expected. Wouldn't 
it be easier to use in-memory data instead of `aggregate_test_100.csv`, so that 
we can craft the values as we please in the test and thus make it easier to 
verify our assumptions?

##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -148,58 +206,165 @@ impl Default for SortOptions {
     }
 }
 
-/// Sort primitive values, excluding floats
+/// Sort primitive values
 fn sort_primitive<T>(
     values: &ArrayRef,
-    value_indices: Vec<usize>,
+    value_indices: Vec<u32>,
     null_indices: Vec<u32>,
+    nan_indices: Vec<u32>,
     options: &SortOptions,
 ) -> Result<UInt32Array>
 where
     T: ArrowPrimitiveType,
     T::Native: std::cmp::PartialOrd,
 {
     let values = as_primitive_array::<T>(values);
+    sort_primitive_typed(values, value_indices, null_indices, nan_indices, 
options)
+}
+
+fn sort_primitive_typed<T>(

Review comment:
       IMO we should document this behavior (order of `nan null value`).
   
   I.e. the description in issue ARROW-9895 should be captured in the code 
base: as in-line comments for developers, as documentation strings for users 
in the logical plan, physical plan and kernel, and potentially also in the 
README.
   

##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -148,58 +206,165 @@ impl Default for SortOptions {
     }
 }
 
-/// Sort primitive values, excluding floats
+/// Sort primitive values
 fn sort_primitive<T>(
     values: &ArrayRef,
-    value_indices: Vec<usize>,
+    value_indices: Vec<u32>,
     null_indices: Vec<u32>,
+    nan_indices: Vec<u32>,
     options: &SortOptions,
 ) -> Result<UInt32Array>
 where
     T: ArrowPrimitiveType,
     T::Native: std::cmp::PartialOrd,
 {
     let values = as_primitive_array::<T>(values);
+    sort_primitive_typed(values, value_indices, null_indices, nan_indices, 
options)
+}
+
+fn sort_primitive_typed<T>(
+    values: &PrimitiveArray<T>,
+    value_indices: Vec<u32>,
+    null_indices: Vec<u32>,
+    nan_indices: Vec<u32>,
+    options: &SortOptions,
+) -> Result<UInt32Array>
+where
+    T: ArrowPrimitiveType,
+    T::Native: std::cmp::PartialOrd,
+{
     // create tuples that are used for sorting
     let mut valids = value_indices
         .into_iter()
-        .map(|index| (index as u32, values.value(index)))
+        .map(|index| (index, values.value(index as usize)))
         .collect::<Vec<(u32, T::Native)>>();
+
+    let valids_len = valids.len();
+
     let mut nulls = null_indices;
+    let mut nans = nan_indices;
+
     if !options.descending {
-        valids.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or_else(|| 
Ordering::Greater));
+        valids.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("unexpected NaN"));
     } else {
-        valids.sort_by(|a, b| {
-            a.1.partial_cmp(&b.1)
-                .unwrap_or_else(|| Ordering::Greater)
-                .reverse()
-        });
+        valids.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("unexpected 
NaN").reverse());
+        // reverse to keep a stable ordering
+        nans.reverse();
         nulls.reverse();
     }
-    // collect the order of valid tuples
-    let mut valid_indices: Vec<u32> = valids.iter().map(|tuple| 
tuple.0).collect();
 
-    if options.nulls_first {
-        nulls.append(&mut valid_indices);
-        return Ok(UInt32Array::from(nulls));
+    // collect results directly into a buffer instead of a vec to avoid 
another aligned allocation
+    let mut result = MutableBuffer::new(values.len() * 
std::mem::size_of::<u32>());
+    // sets len to capacity so we can access the whole buffer as a typed slice
+    result.resize(values.len() * std::mem::size_of::<u32>())?;
+    {
+        let append_valids = move |dst_slice: &mut [u32]| {
+            debug_assert_eq!(dst_slice.len(), valids_len);
+            dst_slice
+                .iter_mut()
+                .zip(valids.into_iter())
+                .for_each(|(dst, src)| *dst = src.0)
+        };
+
+        let result_slice: &mut [u32] = result.typed_data_mut();
+
+        debug_assert_eq!(result_slice.len(), nulls.len() + nans.len() + 
valids_len);
+
+        if options.nulls_first {

Review comment:
       I read this logic carefully and I agree with it. My only suggestion here 
is to add some explanatory comments. For example:
   
   ```
          if options.nulls_first {
               // nulls first
               result_slice[0..nulls.len()].copy_from_slice(&nulls);
               if !options.descending {
                   // valids next
                   append_valids(&mut result_slice[nulls.len()..nulls.len() + 
valids_len]);
                   // nans at the end
                   result_slice[nulls.len() + 
valids_len..].copy_from_slice(nans.as_slice());
               } else {
                   // nans next
                   result_slice[nulls.len()..nulls.len() + nans.len()]
                       .copy_from_slice(nans.as_slice());
                   // valids at the end
                   append_valids(&mut result_slice[nulls.len() + nans.len()..]);
               }
           } else {
               if !options.descending {
                   // valids first
                   append_valids(&mut result_slice[0..valids_len]);
                   // nans next
                   result_slice[valids_len..valids_len + nans.len()]
                       .copy_from_slice(nans.as_slice());
               } else {
                   // nans first
                   result_slice[0..nans.len()].copy_from_slice(nans.as_slice());
                   // valids next
                   append_valids(&mut result_slice[nans.len()..nans.len() + 
valids_len]);
               }
               // nulls at the end
               result_slice[valids_len + 
nans.len()..].copy_from_slice(nulls.as_slice())
           }
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to