mustafasrepo commented on code in PR #4928:
URL: https://github.com/apache/arrow-datafusion/pull/4928#discussion_r1071819266


##########
datafusion/core/src/physical_optimizer/sort_enforcement.rs:
##########
@@ -699,214 +713,147 @@ mod tests {
                 Arc::new(WindowFrame::new(true)),
                 schema.as_ref(),
             )?],
-            filter_exec.clone(),
-            filter_exec.schema(),
+            filter.clone(),
+            filter.schema(),
             vec![],
             Some(sort_exprs),
         )?) as Arc<dyn ExecutionPlan>;
-        let physical_plan = window_agg_exec;
-        let formatted = 
displayable(physical_plan.as_ref()).indent().to_string();
-        let expected = {
-            vec![
-                "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", 
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: 
{} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), 
end_bound: CurrentRow }]",
-                "  FilterExec: NOT non_nullable_col@1",
-                "    SortExec: [non_nullable_col@1 ASC NULLS LAST]",
-                "      WindowAggExec: wdw=[count: Ok(Field { name: \"count\", 
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: 
{} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), 
end_bound: CurrentRow }]",
-                "        SortExec: [non_nullable_col@1 DESC]",
-                "          MemoryExec: partitions=0, partition_sizes=[]",
-            ]
-        };
-        let actual: Vec<&str> = formatted.trim().lines().collect();
-        assert_eq!(
-            expected, actual,
-            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
-        );
-        let optimized_physical_plan =
-            EnforceSorting::new().optimize(physical_plan, 
state.config_options())?;
-        let formatted = displayable(optimized_physical_plan.as_ref())
-            .indent()
-            .to_string();
-        let expected = {
-            vec![
-                "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", 
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: 
{} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: 
Following(NULL) }]",
-                "  FilterExec: NOT non_nullable_col@1",
-                "    WindowAggExec: wdw=[count: Ok(Field { name: \"count\", 
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: 
{} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), 
end_bound: CurrentRow }]",
-                "      SortExec: [non_nullable_col@1 DESC]",
-                "        MemoryExec: partitions=0, partition_sizes=[]",
-            ]
-        };
-        let actual: Vec<&str> = formatted.trim().lines().collect();
-        assert_eq!(
-            expected, actual,
-            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
-        );
+
+        let expected_input = vec![
+            "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: 
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: 
CurrentRow }]",
+            "  FilterExec: NOT non_nullable_col@1",
+            "    SortExec: [non_nullable_col@1 ASC NULLS LAST]",
+            "      WindowAggExec: wdw=[count: Ok(Field { name: \"count\", 
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: 
{} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), 
end_bound: CurrentRow }]",
+            "        SortExec: [non_nullable_col@1 DESC]",
+            "          MemoryExec: partitions=0, partition_sizes=[]",
+        ];
+
+        let expected_optimized = vec![
+            "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: 
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: 
Following(NULL) }]",
+            "  FilterExec: NOT non_nullable_col@1",
+            "    WindowAggExec: wdw=[count: Ok(Field { name: \"count\", 
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: 
{} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), 
end_bound: CurrentRow }]",
+            "      SortExec: [non_nullable_col@1 DESC]",
+            "        MemoryExec: partitions=0, partition_sizes=[]",
+        ];
+        assert_optimized!(expected_input, expected_optimized, physical_plan);
         Ok(())
     }
 
     #[tokio::test]
     async fn test_add_required_sort() -> Result<()> {
-        let session_ctx = SessionContext::new();
-        let state = session_ctx.state();
         let schema = create_test_schema()?;
-        let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?)
-            as Arc<dyn ExecutionPlan>;
-        let sort_exprs = vec![PhysicalSortExpr {
-            expr: col("nullable_col", schema.as_ref()).unwrap(),
-            options: SortOptions::default(),
-        }];
-        let physical_plan = Arc::new(SortPreservingMergeExec::new(sort_exprs, 
source))
-            as Arc<dyn ExecutionPlan>;
-        let formatted = 
displayable(physical_plan.as_ref()).indent().to_string();
-        let expected = { vec!["SortPreservingMergeExec: [nullable_col@0 ASC]"] 
};
-        let actual: Vec<&str> = formatted.trim().lines().collect();
-        let actual_len = actual.len();
-        let actual_trim_last = &actual[..actual_len - 1];
-        assert_eq!(
-            expected, actual_trim_last,
-            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
-        );
-        let optimized_physical_plan =
-            EnforceSorting::new().optimize(physical_plan, 
state.config_options())?;
-        let formatted = displayable(optimized_physical_plan.as_ref())
-            .indent()
-            .to_string();
-        let expected = {
-            vec![
-                "SortPreservingMergeExec: [nullable_col@0 ASC]",
-                "  SortExec: [nullable_col@0 ASC]",
-            ]
-        };
-        let actual: Vec<&str> = formatted.trim().lines().collect();
-        let actual_len = actual.len();
-        let actual_trim_last = &actual[..actual_len - 1];
-        assert_eq!(
-            expected, actual_trim_last,
-            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
-        );
+        let source = memory_exec(&schema);
+
+        let sort_exprs = vec![sort_expr("nullable_col", &schema)];
+
+        let physical_plan = sort_preserving_merge_exec(sort_exprs, source);
+
+        let expected_input = vec![
+            "SortPreservingMergeExec: [nullable_col@0 ASC]",
+            "  MemoryExec: partitions=0, partition_sizes=[]",
+        ];
+        let expected_optimized = vec![
+            "SortPreservingMergeExec: [nullable_col@0 ASC]",
+            "  SortExec: [nullable_col@0 ASC]",
+            "    MemoryExec: partitions=0, partition_sizes=[]",
+        ];
+        assert_optimized!(expected_input, expected_optimized, physical_plan);
         Ok(())
     }
 
     #[tokio::test]
     async fn test_remove_unnecessary_sort1() -> Result<()> {
-        let session_ctx = SessionContext::new();
-        let state = session_ctx.state();
         let schema = create_test_schema()?;
-        let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?)
-            as Arc<dyn ExecutionPlan>;
-        let sort_exprs = vec![PhysicalSortExpr {
-            expr: col("nullable_col", schema.as_ref()).unwrap(),
-            options: SortOptions::default(),
-        }];
-        let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, 
None)?)
-            as Arc<dyn ExecutionPlan>;
-        let sort_preserving_merge_exec =
-            Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec))
-                as Arc<dyn ExecutionPlan>;
-        let sort_exprs = vec![PhysicalSortExpr {
-            expr: col("nullable_col", schema.as_ref()).unwrap(),
-            options: SortOptions::default(),
-        }];
-        let sort_exec = Arc::new(SortExec::try_new(
-            sort_exprs.clone(),
-            sort_preserving_merge_exec,
-            None,
-        )?) as Arc<dyn ExecutionPlan>;
-        let sort_preserving_merge_exec =
-            Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec))
-                as Arc<dyn ExecutionPlan>;
-        let physical_plan = sort_preserving_merge_exec;
-        let formatted = 
displayable(physical_plan.as_ref()).indent().to_string();
-        let expected = {
-            vec![
-                "SortPreservingMergeExec: [nullable_col@0 ASC]",
-                "  SortExec: [nullable_col@0 ASC]",
-                "    SortPreservingMergeExec: [nullable_col@0 ASC]",
-                "      SortExec: [nullable_col@0 ASC]",
-                "        MemoryExec: partitions=0, partition_sizes=[]",
-            ]
-        };
-        let actual: Vec<&str> = formatted.trim().lines().collect();
-        assert_eq!(
-            expected, actual,
-            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
-        );
-        let optimized_physical_plan =
-            EnforceSorting::new().optimize(physical_plan, 
state.config_options())?;
-        let formatted = displayable(optimized_physical_plan.as_ref())
-            .indent()
-            .to_string();
-        let expected = {
-            vec![
-                "SortPreservingMergeExec: [nullable_col@0 ASC]",
-                "  SortPreservingMergeExec: [nullable_col@0 ASC]",
-                "    SortExec: [nullable_col@0 ASC]",
-                "      MemoryExec: partitions=0, partition_sizes=[]",
-            ]
-        };
-        let actual: Vec<&str> = formatted.trim().lines().collect();
-        assert_eq!(
-            expected, actual,
-            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
-        );
+        let source = memory_exec(&schema);
+        let sort_exprs = vec![sort_expr("nullable_col", &schema)];
+        let sort = sort_exec(sort_exprs.clone(), source);
+        let spm = sort_preserving_merge_exec(sort_exprs, sort);
+
+        let sort_exprs = vec![sort_expr("nullable_col", &schema)];
+        let sort = sort_exec(sort_exprs.clone(), spm);
+        let physical_plan = sort_preserving_merge_exec(sort_exprs, sort);
+        let expected_input = vec![
+            "SortPreservingMergeExec: [nullable_col@0 ASC]",
+            "  SortExec: [nullable_col@0 ASC]",
+            "    SortPreservingMergeExec: [nullable_col@0 ASC]",
+            "      SortExec: [nullable_col@0 ASC]",
+            "        MemoryExec: partitions=0, partition_sizes=[]",
+        ];
+        let expected_optimized = vec![
+            "SortPreservingMergeExec: [nullable_col@0 ASC]",
+            "  SortPreservingMergeExec: [nullable_col@0 ASC]",
+            "    SortExec: [nullable_col@0 ASC]",
+            "      MemoryExec: partitions=0, partition_sizes=[]",
+        ];
+        assert_optimized!(expected_input, expected_optimized, physical_plan);
         Ok(())
     }
 
     #[tokio::test]
     async fn test_change_wrong_sorting() -> Result<()> {
-        let session_ctx = SessionContext::new();
-        let state = session_ctx.state();
         let schema = create_test_schema()?;
-        let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?)
-            as Arc<dyn ExecutionPlan>;
+        let source = memory_exec(&schema);
         let sort_exprs = vec![
-            PhysicalSortExpr {
-                expr: col("nullable_col", schema.as_ref()).unwrap(),
-                options: SortOptions::default(),
-            },
-            PhysicalSortExpr {
-                expr: col("non_nullable_col", schema.as_ref()).unwrap(),
-                options: SortOptions::default(),
-            },
+            sort_expr("nullable_col", &schema),
+            sort_expr("non_nullable_col", &schema),
         ];
-        let sort_exec = Arc::new(SortExec::try_new(
-            vec![sort_exprs[0].clone()],
-            source,
-            None,
-        )?) as Arc<dyn ExecutionPlan>;
-        let sort_preserving_merge_exec =
-            Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec))
-                as Arc<dyn ExecutionPlan>;
-        let physical_plan = sort_preserving_merge_exec;
-        let formatted = 
displayable(physical_plan.as_ref()).indent().to_string();
-        let expected = {
-            vec![
-                "SortPreservingMergeExec: [nullable_col@0 
ASC,non_nullable_col@1 ASC]",
-                "  SortExec: [nullable_col@0 ASC]",
-                "    MemoryExec: partitions=0, partition_sizes=[]",
-            ]
-        };
-        let actual: Vec<&str> = formatted.trim().lines().collect();
-        assert_eq!(
-            expected, actual,
-            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
-        );
-        let optimized_physical_plan =
-            EnforceSorting::new().optimize(physical_plan, 
state.config_options())?;
-        let formatted = displayable(optimized_physical_plan.as_ref())
-            .indent()
-            .to_string();
-        let expected = {
-            vec![
-                "SortPreservingMergeExec: [nullable_col@0 
ASC,non_nullable_col@1 ASC]",
-                "  SortExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]",
-                "    MemoryExec: partitions=0, partition_sizes=[]",
-            ]
-        };
-        let actual: Vec<&str> = formatted.trim().lines().collect();
-        assert_eq!(
-            expected, actual,
-            "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
-        );
+        let sort = sort_exec(vec![sort_exprs[0].clone()], source);
+        let physical_plan = sort_preserving_merge_exec(sort_exprs, sort);
+        let expected_input = vec![
+            "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 
ASC]",
+            "  SortExec: [nullable_col@0 ASC]",
+            "    MemoryExec: partitions=0, partition_sizes=[]",
+        ];
+        let expected_optimized = vec![
+            "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 
ASC]",
+            "  SortExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]",
+            "    MemoryExec: partitions=0, partition_sizes=[]",
+        ];
+        assert_optimized!(expected_input, expected_optimized, physical_plan);
         Ok(())
     }
+
+    /// make PhysicalSortExpr with default options
+    fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr {
+        sort_expr_options(name, schema, SortOptions::default())
+    }
+
+    /// PhysicalSortExpr with specified options
+    fn sort_expr_options(
+        name: &str,
+        schema: &Schema,
+        options: SortOptions,
+    ) -> PhysicalSortExpr {
+        PhysicalSortExpr {
+            expr: col(name, schema).unwrap(),
+            options,
+        }
+    }
+
+    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
+        Arc::new(MemoryExec::try_new(&[], schema.clone(), None).unwrap())
+    }
+
+    fn sort_exec(
+        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
+        input: Arc<dyn ExecutionPlan>,
+    ) -> Arc<dyn ExecutionPlan> {
+        let sort_exprs = sort_exprs.into_iter().collect();
+        Arc::new(SortExec::try_new(sort_exprs, input, None).unwrap())
+    }
+
+    fn sort_preserving_merge_exec(
+        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
+        input: Arc<dyn ExecutionPlan>,
+    ) -> Arc<dyn ExecutionPlan> {
+        let sort_exprs = sort_exprs.into_iter().collect();
+        Arc::new(SortPreservingMergeExec::new(sort_exprs, input))
+    }
+

Review Comment:
   Maybe we can add one more helper function here to encapsulate `WindowAggExec` creation:
   ```rust
       fn window_exec(
           input: Arc<dyn ExecutionPlan>,
           schema: SchemaRef,
           sort_exprs: &[PhysicalSortExpr],
           fn_arg_column_name: &str,
       ) -> Result<Arc<dyn ExecutionPlan>> {
           let fn_arg = col(fn_arg_column_name, &schema)?;
           Ok(Arc::new(WindowAggExec::try_new(
               vec![create_window_expr(
                   &WindowFunction::AggregateFunction(AggregateFunction::Count),
                   "count".to_owned(),
                   &[fn_arg],
                   &[],
                   sort_exprs,
                   Arc::new(WindowFrame::new(true)),
                   schema.as_ref(),
               )?],
               input,
               schema,
               vec![],
               Some(sort_exprs.to_vec()),
           )?) as Arc<dyn ExecutionPlan>)
       }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to