This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 036dc48ae1 Fix window sort removal wrong operator. (#7811)
036dc48ae1 is described below

commit 036dc48ae17287939e5f9ba54e7da175ceb60910
Author: Mustafa Akur <[email protected]>
AuthorDate: Thu Oct 12 21:49:00 2023 +0300

    Fix window sort removal wrong operator. (#7811)
---
 .../core/src/physical_optimizer/enforce_sorting.rs | 52 ++++++++++++++++++++--
 .../core/src/physical_optimizer/test_utils.rs      |  8 ++++
 2 files changed, 56 insertions(+), 4 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs
index 95ec1973d0..f84a05f0fd 100644
--- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs
+++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs
@@ -62,6 +62,7 @@ use datafusion_physical_expr::utils::{
 };
 use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement};
 
+use datafusion_physical_plan::repartition::RepartitionExec;
 use itertools::izip;
 
 /// This rule inspects [`SortExec`]'s in the given physical plan and removes 
the
@@ -566,7 +567,6 @@ fn analyze_window_sort_removal(
     );
     let mut window_child =
         remove_corresponding_sort_from_sub_plan(sort_tree, 
requires_single_partition)?;
-
     let (window_expr, new_window) =
         if let Some(exec) = 
window_exec.as_any().downcast_ref::<BoundedWindowAggExec>() {
             (
@@ -704,8 +704,18 @@ fn remove_corresponding_sort_from_sub_plan(
             children[item.idx] =
                 remove_corresponding_sort_from_sub_plan(item, 
requires_single_partition)?;
         }
+        // Replace with variants that do not preserve order.
         if is_sort_preserving_merge(plan) {
             children[0].clone()
+        } else if let Some(repartition) = 
plan.as_any().downcast_ref::<RepartitionExec>()
+        {
+            Arc::new(
+                RepartitionExec::try_new(
+                    children[0].clone(),
+                    repartition.partitioning().clone(),
+                )?
+                .with_preserve_order(false),
+            )
         } else {
             plan.clone().with_new_children(children)?
         }
@@ -758,7 +768,7 @@ mod tests {
         coalesce_partitions_exec, filter_exec, global_limit_exec, 
hash_join_exec,
         limit_exec, local_limit_exec, memory_exec, parquet_exec, 
parquet_exec_sorted,
         repartition_exec, sort_exec, sort_expr, sort_expr_options, 
sort_merge_join_exec,
-        sort_preserving_merge_exec, union_exec,
+        sort_preserving_merge_exec, spr_repartition_exec, union_exec,
     };
     use crate::physical_optimizer::utils::get_plan_string;
     use crate::physical_plan::repartition::RepartitionExec;
@@ -1635,14 +1645,16 @@ mod tests {
         // During the removal of `SortExec`s, it should be able to remove the
         // corresponding SortExecs together. Also, the inputs of these 
`SortExec`s
         // are not necessarily the same to be able to remove them.
-        let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { 
name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: 
false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: 
Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]",
+        let expected_input = [
+            "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", 
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: 
{} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), 
end_bound: CurrentRow }], mode=[Sorted]",
             "  SortPreservingMergeExec: [nullable_col@0 DESC NULLS LAST]",
             "    UnionExec",
             "      SortExec: expr=[nullable_col@0 DESC NULLS LAST]",
             "        ParquetExec: file_groups={1 group: [[x]]}, 
projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 
ASC, non_nullable_col@1 ASC]",
             "      SortExec: expr=[nullable_col@0 DESC NULLS LAST]",
             "        ParquetExec: file_groups={1 group: [[x]]}, 
projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 
ASC]"];
-        let expected_optimized = ["WindowAggExec: wdw=[count: Ok(Field { name: 
\"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: 
false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: 
CurrentRow, end_bound: Following(NULL) }]",
+        let expected_optimized = [
+            "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: 
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: 
Following(NULL) }]",
             "  SortPreservingMergeExec: [nullable_col@0 ASC]",
             "    UnionExec",
             "      ParquetExec: file_groups={1 group: [[x]]}, 
projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 
ASC, non_nullable_col@1 ASC]",
@@ -2234,4 +2246,36 @@ mod tests {
         assert_optimized!(expected_input, expected_optimized, physical_plan, 
false);
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_window_multi_layer_requirement() -> Result<()> {
+        // Ordering comes from a `SortExec` beneath a plain and an order-preserving repartition.
+        // Expected: the sort moves above both (now plain) repartitions and the SPM disappears.
+        let schema = create_test_schema3()?;
+        let ordering = vec![sort_expr("a", &schema), sort_expr("b", &schema)];
+        let csv = csv_exec_sorted(&schema, vec![], false);
+        let sorted = sort_exec(ordering.clone(), csv);
+        let round_robin = spr_repartition_exec(repartition_exec(sorted));
+        let merged = sort_preserving_merge_exec(ordering.clone(), round_robin);
+        let physical_plan = bounded_window_exec("a", ordering, merged);
+
+        let expected_input = [
+            "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]",
+            "  SortPreservingMergeExec: [a@0 ASC,b@1 ASC]",
+            "    SortPreservingRepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10",
+            "      RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1",
+            "        SortExec: expr=[a@0 ASC,b@1 ASC]",
+            "          CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false",
+        ];
+        let expected_optimized = [
+            "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]",
+            "  SortExec: expr=[a@0 ASC,b@1 ASC]",
+            "    CoalescePartitionsExec",
+            "      RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10",
+            "        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1",
+            "          CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false",
+        ];
+        assert_optimized!(expected_input, expected_optimized, physical_plan, false);
+        Ok(())
+    }
 }
diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs
index 0915fdbf1c..9f966990b8 100644
--- a/datafusion/core/src/physical_optimizer/test_utils.rs
+++ b/datafusion/core/src/physical_optimizer/test_utils.rs
@@ -324,6 +324,14 @@ pub fn repartition_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan>
     Arc::new(RepartitionExec::try_new(input, 
Partitioning::RoundRobinBatch(10)).unwrap())
 }
 
+/// Like `repartition_exec` (`RoundRobinBatch(10)`), but with `preserve_order` enabled.
+pub fn spr_repartition_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
+    let partitioning = Partitioning::RoundRobinBatch(10);
+    // Rendered as `SortPreservingRepartitionExec` in plan strings.
+    let exec = RepartitionExec::try_new(input, partitioning).unwrap();
+    Arc::new(exec.with_preserve_order(true))
+}
+
 pub fn aggregate_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> 
{
     let schema = input.schema();
     Arc::new(

Reply via email to