This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 22255c2c2c fix: sort_batch function unsupported mixed types with list
(#9410)
22255c2c2c is described below
commit 22255c2c2c17abea73e62e8eee71ac5c2078bc5b
Author: JasonLi <[email protected]>
AuthorDate: Mon Mar 4 19:07:17 2024 +0800
fix: sort_batch function unsupported mixed types with list (#9410)
* fix: sort_batch function unsupported mixed types with list
* chore: add some SQL level end to end tests
* refactor: Use RowConverter only when sorting multiple columns that
contain a List type
* Update datafusion/physical-plan/src/sorts/sort.rs
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/physical-plan/src/sorts/sort.rs | 131 ++++++++++++++++++++++++++-
datafusion/sqllogictest/test_files/order.slt | 107 ++++++++++++++++++++++
2 files changed, 236 insertions(+), 2 deletions(-)
diff --git a/datafusion/physical-plan/src/sorts/sort.rs
b/datafusion/physical-plan/src/sorts/sort.rs
index 5b0f2f3548..db352bb2c8 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -41,10 +41,13 @@ use crate::{
SendableRecordBatchStream, Statistics,
};
-use arrow::compute::{concat_batches, lexsort_to_indices, take};
+use arrow::compute::{concat_batches, lexsort_to_indices, take, SortColumn};
use arrow::datatypes::SchemaRef;
use arrow::ipc::reader::FileReader;
use arrow::record_batch::RecordBatch;
+use arrow::row::{RowConverter, SortField};
+use arrow_array::{Array, UInt32Array};
+use arrow_schema::DataType;
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::disk_manager::RefCountedTempFile;
@@ -588,7 +591,13 @@ pub(crate) fn sort_batch(
.map(|expr| expr.evaluate_to_sort_column(batch))
.collect::<Result<Vec<_>>>()?;
- let indices = lexsort_to_indices(&sort_columns, fetch)?;
+ let indices = if is_multi_column_with_lists(&sort_columns) {
+ // lexsort_to_indices doesn't support List with more than one column
+ // https://github.com/apache/arrow-rs/issues/5454
+ lexsort_to_indices_multi_columns(sort_columns, fetch)?
+ } else {
+ lexsort_to_indices(&sort_columns, fetch)?
+ };
let columns = batch
.columns()
@@ -599,6 +608,48 @@ pub(crate) fn sort_batch(
Ok(RecordBatch::try_new(batch.schema(), columns)?)
}
+#[inline]
+fn is_multi_column_with_lists(sort_columns: &[SortColumn]) -> bool {
+ sort_columns.iter().any(|c| {
+ matches!(
+ c.values.data_type(),
+ DataType::List(_) | DataType::LargeList(_) |
DataType::FixedSizeList(_, _)
+ )
+ })
+}
+
+pub(crate) fn lexsort_to_indices_multi_columns(
+ sort_columns: Vec<SortColumn>,
+ limit: Option<usize>,
+) -> Result<UInt32Array> {
+ let (fields, columns) = sort_columns.into_iter().fold(
+ (vec![], vec![]),
+ |(mut fields, mut columns), sort_column| {
+ fields.push(SortField::new_with_options(
+ sort_column.values.data_type().clone(),
+ sort_column.options.unwrap_or_default(),
+ ));
+ columns.push(sort_column.values);
+ (fields, columns)
+ },
+ );
+
+ // TODO reuse converter and rows, refer to TopK.
+ let converter = RowConverter::new(fields)?;
+ let rows = converter.convert_columns(&columns)?;
+ let mut sort: Vec<_> = rows.iter().enumerate().collect();
+ sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b));
+
+ let mut len = rows.num_rows();
+ if let Some(limit) = limit {
+ len = limit.min(len);
+ }
+ let indices =
+ UInt32Array::from_iter_values(sort.iter().take(len).map(|(i, _)| *i as
u32));
+
+ Ok(indices)
+}
+
async fn spill_sorted_batches(
batches: Vec<RecordBatch>,
path: &Path,
@@ -1159,6 +1210,82 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn test_lex_sort_by_mixed_types() -> Result<()> {
+ let task_ctx = Arc::new(TaskContext::default());
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int32, true),
+ Field::new(
+ "b",
+ DataType::List(Arc::new(Field::new("item", DataType::Int32,
true))),
+ true,
+ ),
+ ]));
+
+ // define data.
+ let batch = RecordBatch::try_new(
+ schema.clone(),
+ vec![
+ Arc::new(Int32Array::from(vec![Some(2), None, Some(1),
Some(2)])),
+ Arc::new(ListArray::from_iter_primitive::<Int32Type, _,
_>(vec![
+ Some(vec![Some(3)]),
+ Some(vec![Some(1)]),
+ Some(vec![Some(6), None]),
+ Some(vec![Some(5)]),
+ ])),
+ ],
+ )?;
+
+ let sort_exec = Arc::new(SortExec::new(
+ vec![
+ PhysicalSortExpr {
+ expr: col("a", &schema)?,
+ options: SortOptions {
+ descending: false,
+ nulls_first: true,
+ },
+ },
+ PhysicalSortExpr {
+ expr: col("b", &schema)?,
+ options: SortOptions {
+ descending: true,
+ nulls_first: false,
+ },
+ },
+ ],
+ Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(),
None)?),
+ ));
+
+ assert_eq!(DataType::Int32, *sort_exec.schema().field(0).data_type());
+ assert_eq!(
+ DataType::List(Arc::new(Field::new("item", DataType::Int32,
true))),
+ *sort_exec.schema().field(1).data_type()
+ );
+
+ let result: Vec<RecordBatch> = collect(sort_exec.clone(),
task_ctx).await?;
+ let metrics = sort_exec.metrics().unwrap();
+ assert!(metrics.elapsed_compute().unwrap() > 0);
+ assert_eq!(metrics.output_rows().unwrap(), 4);
+ assert_eq!(result.len(), 1);
+
+ let expected = RecordBatch::try_new(
+ schema,
+ vec![
+ Arc::new(Int32Array::from(vec![None, Some(1), Some(2),
Some(2)])),
+ Arc::new(ListArray::from_iter_primitive::<Int32Type, _,
_>(vec![
+ Some(vec![Some(1)]),
+ Some(vec![Some(6), None]),
+ Some(vec![Some(5)]),
+ Some(vec![Some(3)]),
+ ])),
+ ],
+ )?;
+
+ assert_eq!(expected, result[0]);
+
+ Ok(())
+ }
+
#[tokio::test]
async fn test_lex_sort_by_float() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
diff --git a/datafusion/sqllogictest/test_files/order.slt
b/datafusion/sqllogictest/test_files/order.slt
index 2ea78448b9..f63179a369 100644
--- a/datafusion/sqllogictest/test_files/order.slt
+++ b/datafusion/sqllogictest/test_files/order.slt
@@ -784,3 +784,110 @@ SortPreservingMergeExec: [m@0 ASC NULLS LAST,t@1 ASC
NULLS LAST]
----------------AggregateExec: mode=Partial, gby=[1 as Int64(1), t@0 as t],
aggr=[], ordering_mode=PartiallySorted([0])
------------------ProjectionExec: expr=[column1@0 as t]
--------------------ValuesExec
+
+#####
+# Multi column sorting with lists
+#####
+
+statement ok
+create table foo as values (2, [0]), (4, [1]), (2, [6]), (1, [2, 5]), (0,
[3]), (3, [4]), (2, [2, 5]), (2, [7]);
+
+query I?
+select column1, column2 from foo ORDER BY column1, column2;
+----
+0 [3]
+1 [2, 5]
+2 [0]
+2 [2, 5]
+2 [6]
+2 [7]
+3 [4]
+4 [1]
+
+query I?
+select column1, column2 from foo ORDER BY column1 desc, column2;
+----
+4 [1]
+3 [4]
+2 [0]
+2 [2, 5]
+2 [6]
+2 [7]
+1 [2, 5]
+0 [3]
+
+query I?
+select column1, column2 from foo ORDER BY column1, column2 desc;
+----
+0 [3]
+1 [2, 5]
+2 [7]
+2 [6]
+2 [2, 5]
+2 [0]
+3 [4]
+4 [1]
+
+query I?
+select column1, column2 from foo ORDER BY column1 desc, column2 desc;
+----
+4 [1]
+3 [4]
+2 [7]
+2 [6]
+2 [2, 5]
+2 [0]
+1 [2, 5]
+0 [3]
+
+query ?I
+select column2, column1 from foo ORDER BY column2, column1;
+----
+[0] 2
+[1] 4
+[2, 5] 1
+[2, 5] 2
+[3] 0
+[4] 3
+[6] 2
+[7] 2
+
+query ?I
+select column2, column1 from foo ORDER BY column2 desc, column1;
+----
+[7] 2
+[6] 2
+[4] 3
+[3] 0
+[2, 5] 1
+[2, 5] 2
+[1] 4
+[0] 2
+
+query ?I
+select column2, column1 from foo ORDER BY column2, column1 desc;
+----
+[0] 2
+[1] 4
+[2, 5] 2
+[2, 5] 1
+[3] 0
+[4] 3
+[6] 2
+[7] 2
+
+query ?I
+select column2, column1 from foo ORDER BY column2 desc, column1 desc;
+----
+[7] 2
+[6] 2
+[4] 3
+[3] 0
+[2, 5] 2
+[2, 5] 1
+[1] 4
+[0] 2
+
+# Cleanup
+statement ok
+drop table foo;