This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 22255c2c2c fix: sort_batch function unsupported mixed types with list 
(#9410)
22255c2c2c is described below

commit 22255c2c2c17abea73e62e8eee71ac5c2078bc5b
Author: JasonLi <[email protected]>
AuthorDate: Mon Mar 4 19:07:17 2024 +0800

    fix: sort_batch function unsupported mixed types with list (#9410)
    
    * fix: sort_batch function unsupported mixed types with list
    
    * chore: add some SQL level end to end tests
    
    * refactor: Use RowConverter only when sorting multiple columns that 
contain a List type
    
    * Update datafusion/physical-plan/src/sorts/sort.rs
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/physical-plan/src/sorts/sort.rs   | 131 ++++++++++++++++++++++++++-
 datafusion/sqllogictest/test_files/order.slt | 107 ++++++++++++++++++++++
 2 files changed, 236 insertions(+), 2 deletions(-)

diff --git a/datafusion/physical-plan/src/sorts/sort.rs 
b/datafusion/physical-plan/src/sorts/sort.rs
index 5b0f2f3548..db352bb2c8 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -41,10 +41,13 @@ use crate::{
     SendableRecordBatchStream, Statistics,
 };
 
-use arrow::compute::{concat_batches, lexsort_to_indices, take};
+use arrow::compute::{concat_batches, lexsort_to_indices, take, SortColumn};
 use arrow::datatypes::SchemaRef;
 use arrow::ipc::reader::FileReader;
 use arrow::record_batch::RecordBatch;
+use arrow::row::{RowConverter, SortField};
+use arrow_array::{Array, UInt32Array};
+use arrow_schema::DataType;
 use datafusion_common::{exec_err, DataFusionError, Result};
 use datafusion_common_runtime::SpawnedTask;
 use datafusion_execution::disk_manager::RefCountedTempFile;
@@ -588,7 +591,13 @@ pub(crate) fn sort_batch(
         .map(|expr| expr.evaluate_to_sort_column(batch))
         .collect::<Result<Vec<_>>>()?;
 
-    let indices = lexsort_to_indices(&sort_columns, fetch)?;
+    let indices = if is_multi_column_with_lists(&sort_columns) {
+        // lexsort_to_indices doesn't support List with more than one column
+        // https://github.com/apache/arrow-rs/issues/5454
+        lexsort_to_indices_multi_columns(sort_columns, fetch)?
+    } else {
+        lexsort_to_indices(&sort_columns, fetch)?
+    };
 
     let columns = batch
         .columns()
@@ -599,6 +608,48 @@ pub(crate) fn sort_batch(
     Ok(RecordBatch::try_new(batch.schema(), columns)?)
 }
 
+#[inline]
+fn is_multi_column_with_lists(sort_columns: &[SortColumn]) -> bool {
+    sort_columns.iter().any(|c| {
+        matches!(
+            c.values.data_type(),
+            DataType::List(_) | DataType::LargeList(_) | 
DataType::FixedSizeList(_, _)
+        )
+    })
+}
+
+pub(crate) fn lexsort_to_indices_multi_columns(
+    sort_columns: Vec<SortColumn>,
+    limit: Option<usize>,
+) -> Result<UInt32Array> {
+    let (fields, columns) = sort_columns.into_iter().fold(
+        (vec![], vec![]),
+        |(mut fields, mut columns), sort_column| {
+            fields.push(SortField::new_with_options(
+                sort_column.values.data_type().clone(),
+                sort_column.options.unwrap_or_default(),
+            ));
+            columns.push(sort_column.values);
+            (fields, columns)
+        },
+    );
+
+    // TODO reuse converter and rows, refer to TopK.
+    let converter = RowConverter::new(fields)?;
+    let rows = converter.convert_columns(&columns)?;
+    let mut sort: Vec<_> = rows.iter().enumerate().collect();
+    sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b));
+
+    let mut len = rows.num_rows();
+    if let Some(limit) = limit {
+        len = limit.min(len);
+    }
+    let indices =
+        UInt32Array::from_iter_values(sort.iter().take(len).map(|(i, _)| *i as 
u32));
+
+    Ok(indices)
+}
+
 async fn spill_sorted_batches(
     batches: Vec<RecordBatch>,
     path: &Path,
@@ -1159,6 +1210,82 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_lex_sort_by_mixed_types() -> Result<()> {
+        let task_ctx = Arc::new(TaskContext::default());
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new(
+                "b",
+                DataType::List(Arc::new(Field::new("item", DataType::Int32, 
true))),
+                true,
+            ),
+        ]));
+
+        // define data.
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![Some(2), None, Some(1), 
Some(2)])),
+                Arc::new(ListArray::from_iter_primitive::<Int32Type, _, 
_>(vec![
+                    Some(vec![Some(3)]),
+                    Some(vec![Some(1)]),
+                    Some(vec![Some(6), None]),
+                    Some(vec![Some(5)]),
+                ])),
+            ],
+        )?;
+
+        let sort_exec = Arc::new(SortExec::new(
+            vec![
+                PhysicalSortExpr {
+                    expr: col("a", &schema)?,
+                    options: SortOptions {
+                        descending: false,
+                        nulls_first: true,
+                    },
+                },
+                PhysicalSortExpr {
+                    expr: col("b", &schema)?,
+                    options: SortOptions {
+                        descending: true,
+                        nulls_first: false,
+                    },
+                },
+            ],
+            Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), 
None)?),
+        ));
+
+        assert_eq!(DataType::Int32, *sort_exec.schema().field(0).data_type());
+        assert_eq!(
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, 
true))),
+            *sort_exec.schema().field(1).data_type()
+        );
+
+        let result: Vec<RecordBatch> = collect(sort_exec.clone(), 
task_ctx).await?;
+        let metrics = sort_exec.metrics().unwrap();
+        assert!(metrics.elapsed_compute().unwrap() > 0);
+        assert_eq!(metrics.output_rows().unwrap(), 4);
+        assert_eq!(result.len(), 1);
+
+        let expected = RecordBatch::try_new(
+            schema,
+            vec![
+                Arc::new(Int32Array::from(vec![None, Some(1), Some(2), 
Some(2)])),
+                Arc::new(ListArray::from_iter_primitive::<Int32Type, _, 
_>(vec![
+                    Some(vec![Some(1)]),
+                    Some(vec![Some(6), None]),
+                    Some(vec![Some(5)]),
+                    Some(vec![Some(3)]),
+                ])),
+            ],
+        )?;
+
+        assert_eq!(expected, result[0]);
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_lex_sort_by_float() -> Result<()> {
         let task_ctx = Arc::new(TaskContext::default());
diff --git a/datafusion/sqllogictest/test_files/order.slt 
b/datafusion/sqllogictest/test_files/order.slt
index 2ea78448b9..f63179a369 100644
--- a/datafusion/sqllogictest/test_files/order.slt
+++ b/datafusion/sqllogictest/test_files/order.slt
@@ -784,3 +784,110 @@ SortPreservingMergeExec: [m@0 ASC NULLS LAST,t@1 ASC 
NULLS LAST]
 ----------------AggregateExec: mode=Partial, gby=[1 as Int64(1), t@0 as t], 
aggr=[], ordering_mode=PartiallySorted([0])
 ------------------ProjectionExec: expr=[column1@0 as t]
 --------------------ValuesExec
+
+#####
+# Multi column sorting with lists
+#####
+
+statement ok
+create table foo as values (2, [0]), (4, [1]), (2, [6]), (1, [2, 5]), (0, 
[3]), (3, [4]), (2, [2, 5]), (2, [7]);
+
+query I?
+select column1, column2 from foo ORDER BY column1, column2;
+----
+0 [3]
+1 [2, 5]
+2 [0]
+2 [2, 5]
+2 [6]
+2 [7]
+3 [4]
+4 [1]
+
+query I?
+select column1, column2 from foo ORDER BY column1 desc, column2;
+----
+4 [1]
+3 [4]
+2 [0]
+2 [2, 5]
+2 [6]
+2 [7]
+1 [2, 5]
+0 [3]
+
+query I?
+select column1, column2 from foo ORDER BY column1, column2 desc;
+----
+0 [3]
+1 [2, 5]
+2 [7]
+2 [6]
+2 [2, 5]
+2 [0]
+3 [4]
+4 [1]
+
+query I?
+select column1, column2 from foo ORDER BY column1 desc, column2 desc;
+----
+4 [1]
+3 [4]
+2 [7]
+2 [6]
+2 [2, 5]
+2 [0]
+1 [2, 5]
+0 [3]
+
+query ?I
+select column2, column1 from foo ORDER BY column2, column1;
+----
+[0] 2
+[1] 4
+[2, 5] 1
+[2, 5] 2
+[3] 0
+[4] 3
+[6] 2
+[7] 2
+
+query ?I
+select column2, column1 from foo ORDER BY column2 desc, column1;
+----
+[7] 2
+[6] 2
+[4] 3
+[3] 0
+[2, 5] 1
+[2, 5] 2
+[1] 4
+[0] 2
+
+query ?I
+select column2, column1 from foo ORDER BY column2, column1 desc;
+----
+[0] 2
+[1] 4
+[2, 5] 2
+[2, 5] 1
+[3] 0
+[4] 3
+[6] 2
+[7] 2
+
+query ?I
+select column2, column1 from foo ORDER BY column2 desc, column1 desc;
+----
+[7] 2
+[6] 2
+[4] 3
+[3] 0
+[2, 5] 2
+[2, 5] 1
+[1] 4
+[0] 2
+
+# Cleanup
+statement ok
+drop table foo;

Reply via email to