rluvaton commented on code in PR #19494:
URL: https://github.com/apache/datafusion/pull/19494#discussion_r2650896461


##########
datafusion/physical-plan/src/sorts/sort.rs:
##########
@@ -2402,4 +2454,757 @@ mod tests {
 
         Ok((sorted_batches, metrics))
     }
+
+    // ========================================================================
+    // Tests for sort_batch_chunked()
+    // ========================================================================
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_basic() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", 
DataType::Int32, false)]));
+
+        // Create a batch with 1000 rows
+        let mut values: Vec<i32> = (0..1000).collect();
+        // Reverse into descending order so the input is not already sorted
+        values.reverse();
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(values))],
+        )?;
+
+        let expressions: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 
0)))].into();
+
+        // Sort with batch_size = 250
+        let result_batches = sort_batch_chunked(&batch, &expressions, 250)?;
+
+        // Verify 4 batches are returned
+        assert_eq!(result_batches.len(), 4);
+
+        // Verify each batch has <= 250 rows
+        let mut total_rows = 0;
+        for (i, batch) in result_batches.iter().enumerate() {
+            assert!(
+                batch.num_rows() <= 250,
+                "Batch {} has {} rows, expected <= 250",
+                i,
+                batch.num_rows()
+            );
+            total_rows += batch.num_rows();
+        }
+
+        // Verify total row count matches input
+        assert_eq!(total_rows, 1000);
+
+        // Verify data is correctly sorted across all chunks
+        let concatenated = concat_batches(&schema, &result_batches)?;
+        let array = as_primitive_array::<Int32Type>(concatenated.column(0))?;
+        for i in 0..array.len() - 1 {
+            assert!(
+                array.value(i) <= array.value(i + 1),
+                "Array not sorted at position {}: {} > {}",
+                i,
+                array.value(i),
+                array.value(i + 1)
+            );
+        }
+        assert_eq!(array.value(0), 0);
+        assert_eq!(array.value(array.len() - 1), 999);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_smaller_than_batch_size() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", 
DataType::Int32, false)]));
+
+        // Create a batch with 50 rows
+        let values: Vec<i32> = (0..50).rev().collect();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(values))],
+        )?;
+
+        let expressions: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 
0)))].into();
+
+        // Sort with batch_size = 100
+        let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
+
+        // Should return exactly 1 batch
+        assert_eq!(result_batches.len(), 1);
+        assert_eq!(result_batches[0].num_rows(), 50);
+
+        // Verify it's correctly sorted
+        let array = 
as_primitive_array::<Int32Type>(result_batches[0].column(0))?;
+        for i in 0..array.len() - 1 {
+            assert!(array.value(i) <= array.value(i + 1));
+        }
+        assert_eq!(array.value(0), 0);
+        assert_eq!(array.value(49), 49);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_exact_multiple() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", 
DataType::Int32, false)]));
+
+        // Create a batch with 1000 rows
+        let values: Vec<i32> = (0..1000).rev().collect();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(values))],
+        )?;
+
+        let expressions: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 
0)))].into();
+
+        // Sort with batch_size = 100
+        let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
+
+        // Should return exactly 10 batches of 100 rows each
+        assert_eq!(result_batches.len(), 10);
+        for batch in &result_batches {
+            assert_eq!(batch.num_rows(), 100);
+        }
+
+        // Verify sorted correctly across all batches
+        let concatenated = concat_batches(&schema, &result_batches)?;
+        let array = as_primitive_array::<Int32Type>(concatenated.column(0))?;
+        for i in 0..array.len() - 1 {
+            assert!(array.value(i) <= array.value(i + 1));
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_with_nulls() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", 
DataType::Int32, true)]));
+
+        // Create a batch with nulls
+        let values = Int32Array::from(vec![
+            Some(5),
+            None,
+            Some(2),
+            Some(8),
+            None,
+            Some(1),
+            Some(9),
+            None,
+            Some(3),
+            Some(7),
+        ]);
+        let batch = RecordBatch::try_new(Arc::clone(&schema), 
vec![Arc::new(values)])?;
+
+        // Test with nulls_first = true
+        {
+            let expressions: LexOrdering = [PhysicalSortExpr {
+                expr: Arc::new(Column::new("a", 0)),
+                options: SortOptions {
+                    descending: false,
+                    nulls_first: true,
+                },
+            }]
+            .into();
+
+            let result_batches = sort_batch_chunked(&batch, &expressions, 4)?;
+            let concatenated = concat_batches(&schema, &result_batches)?;
+            let array = 
as_primitive_array::<Int32Type>(concatenated.column(0))?;
+
+            // First 3 should be null
+            assert!(array.is_null(0));
+            assert!(array.is_null(1));
+            assert!(array.is_null(2));
+            // Then sorted values
+            assert_eq!(array.value(3), 1);
+            assert_eq!(array.value(4), 2);
+        }
+
+        // Test with nulls_first = false
+        {
+            let expressions: LexOrdering = [PhysicalSortExpr {
+                expr: Arc::new(Column::new("a", 0)),
+                options: SortOptions {
+                    descending: false,
+                    nulls_first: false,
+                },
+            }]
+            .into();
+
+            let result_batches = sort_batch_chunked(&batch, &expressions, 4)?;
+            let concatenated = concat_batches(&schema, &result_batches)?;
+            let array = 
as_primitive_array::<Int32Type>(concatenated.column(0))?;
+
+            // First should be 1
+            assert_eq!(array.value(0), 1);
+            // Last 3 should be null
+            assert!(array.is_null(7));
+            assert!(array.is_null(8));
+            assert!(array.is_null(9));
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_multi_column() -> Result<()> {

Review Comment:
   This tests the sorting



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to