alamb commented on code in PR #4691:
URL: https://github.com/apache/arrow-datafusion/pull/4691#discussion_r1055855083
##########
datafusion/common/src/lib.rs:
##########
@@ -63,3 +64,28 @@ macro_rules! downcast_value {
})?
}};
}
+
+/// Computes the "reverse" of given `SortOptions`.
+// TODO: If/when arrow supports `!` for `SortOptions`, we can remove this.
+pub fn reverse_sort_options(options: SortOptions) -> SortOptions {
+ SortOptions {
+ descending: !options.descending,
+ nulls_first: !options.nulls_first,
+ }
+}
+
+/// Transposes the given vector of vectors.
+pub fn transpose<T>(original: Vec<Vec<T>>) -> Vec<Vec<T>> {
Review Comment:
I suggest moving this function into the physical plan code (or some other part
of DataFusion). Its only caller is `window_agg_exec.rs`, so I don't see a reason for it to be shared across all crates:
```
/Users/alamb/Software/arrow-datafusion2/datafusion/common/src/lib.rs
78: pub fn transpose<T>(original: Vec<Vec<T>>) -> Vec<Vec<T>> {
/Users/alamb/Software/arrow-datafusion2/datafusion/core/src/physical_plan/windows/window_agg_exec.rs
40: use datafusion_common::{transpose, DataFusionError};
420: let mut columns = transpose(partition_results)
```
##########
datafusion/core/src/physical_plan/windows/window_agg_exec.rs:
##########
@@ -131,6 +135,27 @@ impl WindowAggExec {
pub fn input_schema(&self) -> SchemaRef {
self.input_schema.clone()
}
+
+ /// Get partition keys
Review Comment:
```suggestion
/// Return the overall output sort order of this window aggregate:
/// `PARTITION BY` expressions followed by `ORDER BY` expressions.
```
##########
datafusion/core/src/physical_plan/windows/window_agg_exec.rs:
##########
@@ -131,6 +135,27 @@ impl WindowAggExec {
pub fn input_schema(&self) -> SchemaRef {
self.input_schema.clone()
}
+
+ /// Get partition keys
+ pub fn partition_by_sort_keys(&self) -> Result<Vec<PhysicalSortExpr>> {
+ let mut result = vec![];
+ // All window exprs have the same partition by, so we just use the
first one:
+ let partition_by = self.window_expr()[0].partition_by();
Review Comment:
Could this use `self.partition_keys` instead of the first window expression's `partition_by()`? Maybe it doesn't matter.
##########
datafusion/core/tests/sql/explain_analyze.rs:
##########
@@ -61,11 +61,6 @@ async fn explain_analyze_baseline_metrics() {
"AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]",
"metrics=[output_rows=5, elapsed_compute="
);
- assert_metrics!(
Review Comment:
nice
##########
datafusion/physical-expr/src/window/built_in_window_function_expr.rs:
##########
@@ -58,4 +58,9 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync +
std::fmt::Debug {
/// Create built-in window evaluator with a batch
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>>;
+
+ /// Construct Reverse Expression
Review Comment:
```suggestion
/// Construct the reverse expression, which produces the same result
/// on a reversed window. For example, `lead(10)` --> `lag(10)`.
```
##########
datafusion/physical-expr/src/window/aggregate.rs:
##########
@@ -155,4 +149,17 @@ impl WindowExpr for AggregateWindowExpr {
fn get_window_frame(&self) -> &Arc<WindowFrame> {
&self.window_frame
}
+
+ fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
+ if let Some(reverse_expr) = self.aggregate.reverse_expr() {
Review Comment:
This could be written in a more functional style if you wanted, e.g. with `Option::map` instead of `if let Some(..)` (see my comment in
`datafusion/physical-expr/src/window/built_in.rs`).
##########
datafusion/core/src/physical_optimizer/utils.rs:
##########
@@ -45,3 +50,72 @@ pub fn optimize_children(
with_new_children_if_necessary(plan, children)
}
}
+
+/// Checks whether given ordering requirements are satisfied by provided
[PhysicalSortExpr]s.
+pub fn ordering_satisfy<F: FnOnce() -> EquivalenceProperties>(
+ provided: Option<&[PhysicalSortExpr]>,
+ required: Option<&[PhysicalSortExpr]>,
+ equal_properties: F,
+) -> bool {
+ match (provided, required) {
+ (_, None) => true,
+ (None, Some(_)) => false,
+ (Some(provided), Some(required)) => {
+ ordering_satisfy_concrete(provided, required, equal_properties)
+ }
+ }
+}
+
+pub fn ordering_satisfy_concrete<F: FnOnce() -> EquivalenceProperties>(
+ provided: &[PhysicalSortExpr],
+ required: &[PhysicalSortExpr],
+ equal_properties: F,
+) -> bool {
+ if required.len() > provided.len() {
+ false
+ } else if required
+ .iter()
+ .zip(provided.iter())
+ .all(|(order1, order2)| order1.eq(order2))
+ {
+ true
+ } else if let eq_classes @ [_, ..] = equal_properties().classes() {
+ let normalized_required_exprs = required
+ .iter()
+ .map(|e| {
+ normalize_sort_expr_with_equivalence_properties(e.clone(),
eq_classes)
+ })
+ .collect::<Vec<_>>();
+ let normalized_provided_exprs = provided
+ .iter()
+ .map(|e| {
+ normalize_sort_expr_with_equivalence_properties(e.clone(),
eq_classes)
+ })
+ .collect::<Vec<_>>();
+ normalized_required_exprs
+ .iter()
+ .zip(normalized_provided_exprs.iter())
+ .all(|(order1, order2)| order1.eq(order2))
+ } else {
+ false
+ }
+}
+
+/// Util function to add SortExec above child
Review Comment:
```suggestion
/// Util function to add SortExec above child
/// preserving the original partitioning
```
##########
datafusion/core/tests/sql/window.rs:
##########
@@ -1772,3 +1775,568 @@ async fn test_window_partition_by_order_by() ->
Result<()> {
);
Ok(())
}
+
+#[tokio::test]
+async fn test_window_agg_sort_reversed_plan() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as
sum1,
+ SUM(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as
sum2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ // Only 1 SortExec was added
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9)
ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9
DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field {
name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1))
}]",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field {
name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
+ " SortExec: [c9@0 DESC]",
Review Comment:
👍
##########
datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs:
##########
@@ -0,0 +1,887 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! RemoveUnnecessarySorts optimizer rule inspects SortExec's in the given
+//! physical plan and removes the ones it can prove unnecessary. The rule can
+//! work on valid *and* invalid physical plans with respect to sorting
+//! requirements, but always produces a valid physical plan in this sense.
+//!
+//! A non-realistic but easy to follow example: Assume that we somehow get the
fragment
+//! "SortExec: [nullable_col@0 ASC]",
+//! " SortExec: [non_nullable_col@1 ASC]",
+//! in the physical plan. The first sort is unnecessary since its result is
overwritten
+//! by another SortExec. Therefore, this rule removes it from the physical
plan.
+use crate::error::Result;
+use crate::physical_optimizer::utils::{
+ add_sort_above_child, ordering_satisfy, ordering_satisfy_concrete,
+};
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::rewrite::TreeNodeRewritable;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::windows::WindowAggExec;
+use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan};
+use crate::prelude::SessionConfig;
+use arrow::datatypes::SchemaRef;
+use datafusion_common::{reverse_sort_options, DataFusionError};
+use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+use itertools::izip;
+use std::iter::zip;
+use std::sync::Arc;
+
+/// This rule inspects SortExec's in the given physical plan and removes the
+/// ones it can prove unnecessary.
+#[derive(Default)]
+pub struct RemoveUnnecessarySorts {}
+
+impl RemoveUnnecessarySorts {
+ #[allow(missing_docs)]
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+/// This is a "data class" we use within the [RemoveUnnecessarySorts] rule
+/// that tracks the closest `SortExec` descendant for every child of a plan.
+#[derive(Debug, Clone)]
+struct PlanWithCorrespondingSort {
+ plan: Arc<dyn ExecutionPlan>,
+ // For every child, keep a vector of `ExecutionPlan`s starting from the
+ // closest `SortExec` till the current plan. The first index of the tuple
is
+ // the child index of the plan -- we need this information as we make
updates.
+ sort_onwards: Vec<Vec<(usize, Arc<dyn ExecutionPlan>)>>,
+}
+
+impl PlanWithCorrespondingSort {
+ pub fn new(plan: Arc<dyn ExecutionPlan>) -> Self {
+ let length = plan.children().len();
+ PlanWithCorrespondingSort {
+ plan,
+ sort_onwards: vec![vec![]; length],
+ }
+ }
+
+ pub fn children(&self) -> Vec<PlanWithCorrespondingSort> {
+ self.plan
+ .children()
+ .into_iter()
+ .map(|child| PlanWithCorrespondingSort::new(child))
+ .collect()
+ }
+}
+
+impl TreeNodeRewritable for PlanWithCorrespondingSort {
+ fn map_children<F>(self, transform: F) -> Result<Self>
+ where
+ F: FnMut(Self) -> Result<Self>,
+ {
+ let children = self.children();
+ if children.is_empty() {
+ Ok(self)
+ } else {
+ let children_requirements = children
Review Comment:
I was expecting to see a check like the following in this module (pseudocode):
```rust
// if there is a sort node child
if node.child.is_sort() {
    let sort = node.child;
    // if the sort's input ordering already satisfies the output
    // required ordering, it can be removed
    if ordering_satisfy_concrete(
        sort.child.output_ordering(),
        node.required_input_ordering(),
    ) {
        // remove sort and connect child directly to node
        node.child = sort.child;
    }
}
```
##########
datafusion/core/tests/sql/window.rs:
##########
@@ -1772,3 +1775,568 @@ async fn test_window_partition_by_order_by() ->
Result<()> {
);
Ok(())
}
+
+#[tokio::test]
+async fn test_window_agg_sort_reversed_plan() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as
sum1,
+ SUM(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as
sum2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ // Only 1 SortExec was added
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9)
ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9
DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field {
name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1))
}]",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field {
name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5))
}]",
+ " SortExec: [c9@0 DESC]",
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+------------+-------------+-------------+",
+ "| c9 | sum1 | sum2 |",
+ "+------------+-------------+-------------+",
+ "| 4268716378 | 8498370520 | 24997484146 |",
+ "| 4229654142 | 12714811027 | 29012926487 |",
+ "| 4216440507 | 16858984380 | 28743001064 |",
+ "| 4144173353 | 20935849039 | 28472563256 |",
+ "| 4076864659 | 24997484146 | 28118515915 |",
+ "+------------+-------------+-------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_window_agg_sort_reversed_plan_builtin() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ FIRST_VALUE(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING) as fv1,
+ FIRST_VALUE(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING) as fv2,
+ LAG(c9, 2, 10101) OVER(ORDER BY c9 ASC) as lag1,
+ LAG(c9, 2, 10101) OVER(ORDER BY c9 DESC ROWS BETWEEN 10 PRECEDING and 1
FOLLOWING) as lag2,
+ LEAD(c9, 2, 10101) OVER(ORDER BY c9 ASC) as lead1,
+ LEAD(c9, 2, 10101) OVER(ORDER BY c9 DESC ROWS BETWEEN 10 PRECEDING and 1
FOLLOWING) as lead2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ // Only 1 SortExec was added
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[c9@6 as c9,
FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS
LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as fv1,
FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS
FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as fv2,
LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND
CURRENT ROW@1 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER
BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1
FOLLOWING@4 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND
CURRENT ROW@2 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER
BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1
FOLLOWING@5 as lead2]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9):
Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32,
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame:
WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound:
Following(UInt64(1)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)):
Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\",
data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false,
metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow,
end_bound: Following(UInt32(NULL)) },
LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name:
\"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32,
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame:
WindowFrame { units: Range, start_bound: CurrentRow, end_bound:
Following(UInt32(NULL)) }]",
+ " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9):
Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32,
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame:
WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound:
Following(UInt64(5)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)):
Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\",
data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false,
metadata: {} }), frame: WindowFrame { units: Rows, start_bound:
Preceding(UInt64(10)), end_bound: Following(UInt64(1)) },
LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name:
\"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32,
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame:
WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound:
Following(UInt64(1)) }]",
+ " SortExec: [c9@0 DESC]",
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+
"+------------+------------+------------+------------+------------+------------+------------+",
+ "| c9 | fv1 | fv2 | lag1 | lag2 |
lead1 | lead2 |",
+
"+------------+------------+------------+------------+------------+------------+------------+",
+ "| 4268716378 | 4229654142 | 4268716378 | 4216440507 | 10101 |
10101 | 4216440507 |",
+ "| 4229654142 | 4216440507 | 4268716378 | 4144173353 | 10101 |
10101 | 4144173353 |",
+ "| 4216440507 | 4144173353 | 4229654142 | 4076864659 | 4268716378 |
4268716378 | 4076864659 |",
+ "| 4144173353 | 4076864659 | 4216440507 | 4061635107 | 4229654142 |
4229654142 | 4061635107 |",
+ "| 4076864659 | 4061635107 | 4144173353 | 4015442341 | 4216440507 |
4216440507 | 4015442341 |",
+
"+------------+------------+------------+------------+------------+------------+------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_window_agg_sort_non_reversed_plan() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ ROW_NUMBER() OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING) as rn1,
+ ROW_NUMBER() OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING) as rn2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ // We cannot reverse each window function (ROW_NUMBER is not reversible)
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[c9@2 as c9, ROW_NUMBER() ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING@0 as rn1, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS
FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name:
\"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " SortExec: [c9@1 ASC NULLS LAST]",
+ " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name:
\"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " SortExec: [c9@0 DESC]",
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----------+-----+-----+",
+ "| c9 | rn1 | rn2 |",
+ "+-----------+-----+-----+",
+ "| 28774375 | 1 | 100 |",
+ "| 63044568 | 2 | 99 |",
+ "| 141047417 | 3 | 98 |",
+ "| 141680161 | 4 | 97 |",
+ "| 145294611 | 5 | 96 |",
+ "+-----------+-----+-----+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ SUM(c9) OVER(ORDER BY c9 ASC, c1 ASC, c2 ASC ROWS BETWEEN 1 PRECEDING AND
5 FOLLOWING) as sum1,
+ SUM(c9) OVER(ORDER BY c9 DESC, c1 DESC ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING) as sum2,
+ ROW_NUMBER() OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING) as rn2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ // We cannot reverse each window function (ROW_NUMBER is not reversible)
Review Comment:
👍
##########
datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs:
##########
@@ -0,0 +1,887 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! RemoveUnnecessarySorts optimizer rule inspects SortExec's in the given
+//! physical plan and removes the ones it can prove unnecessary. The rule can
+//! work on valid *and* invalid physical plans with respect to sorting
+//! requirements, but always produces a valid physical plan in this sense.
+//!
+//! A non-realistic but easy to follow example: Assume that we somehow get the
fragment
+//! "SortExec: [nullable_col@0 ASC]",
+//! " SortExec: [non_nullable_col@1 ASC]",
+//! in the physical plan. The first sort is unnecessary since its result is
overwritten
+//! by another SortExec. Therefore, this rule removes it from the physical
plan.
+use crate::error::Result;
+use crate::physical_optimizer::utils::{
+ add_sort_above_child, ordering_satisfy, ordering_satisfy_concrete,
+};
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::rewrite::TreeNodeRewritable;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::windows::WindowAggExec;
+use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan};
+use crate::prelude::SessionConfig;
+use arrow::datatypes::SchemaRef;
+use datafusion_common::{reverse_sort_options, DataFusionError};
+use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+use itertools::izip;
+use std::iter::zip;
+use std::sync::Arc;
+
+/// This rule inspects SortExec's in the given physical plan and removes the
+/// ones it can prove unnecessary.
+#[derive(Default)]
+pub struct RemoveUnnecessarySorts {}
+
+impl RemoveUnnecessarySorts {
+ #[allow(missing_docs)]
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+/// This is a "data class" we use within the [RemoveUnnecessarySorts] rule
+/// that tracks the closest `SortExec` descendant for every child of a plan.
+#[derive(Debug, Clone)]
+struct PlanWithCorrespondingSort {
+ plan: Arc<dyn ExecutionPlan>,
+ // For every child, keep a vector of `ExecutionPlan`s starting from the
+ // closest `SortExec` till the current plan. The first index of the tuple
is
+ // the child index of the plan -- we need this information as we make
updates.
+ sort_onwards: Vec<Vec<(usize, Arc<dyn ExecutionPlan>)>>,
+}
+
+impl PlanWithCorrespondingSort {
+ pub fn new(plan: Arc<dyn ExecutionPlan>) -> Self {
+ let length = plan.children().len();
+ PlanWithCorrespondingSort {
+ plan,
+ sort_onwards: vec![vec![]; length],
+ }
+ }
+
+ pub fn children(&self) -> Vec<PlanWithCorrespondingSort> {
+ self.plan
+ .children()
+ .into_iter()
+ .map(|child| PlanWithCorrespondingSort::new(child))
+ .collect()
+ }
+}
+
+impl TreeNodeRewritable for PlanWithCorrespondingSort {
+ fn map_children<F>(self, transform: F) -> Result<Self>
+ where
+ F: FnMut(Self) -> Result<Self>,
+ {
+ let children = self.children();
+ if children.is_empty() {
+ Ok(self)
+ } else {
+ let children_requirements = children
+ .into_iter()
+ .map(transform)
+ .collect::<Result<Vec<_>>>()?;
+ let children_plans = children_requirements
+ .iter()
+ .map(|elem| elem.plan.clone())
+ .collect::<Vec<_>>();
+ let sort_onwards = children_requirements
+ .iter()
+ .map(|item| {
+ if item.sort_onwards.is_empty() {
+ vec![]
+ } else {
+ // TODO: When `maintains_input_order` returns
Vec<bool>,
+ // pass the order-enforcing sort upwards.
+ item.sort_onwards[0].clone()
+ }
+ })
+ .collect::<Vec<_>>();
+ let plan = with_new_children_if_necessary(self.plan,
children_plans)?;
+ Ok(PlanWithCorrespondingSort { plan, sort_onwards })
+ }
+ }
+}
+
+impl PhysicalOptimizerRule for RemoveUnnecessarySorts {
+ fn optimize(
+ &self,
+ plan: Arc<dyn ExecutionPlan>,
+ _config: &SessionConfig,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ // Execute a post-order traversal to adjust input key ordering:
+ let plan_requirements = PlanWithCorrespondingSort::new(plan);
+ let adjusted =
plan_requirements.transform_up(&remove_unnecessary_sorts)?;
+ Ok(adjusted.plan)
+ }
+
+ fn name(&self) -> &str {
+ "RemoveUnnecessarySorts"
+ }
+
+ fn schema_check(&self) -> bool {
+ true
+ }
+}
+
+fn remove_unnecessary_sorts(
+ requirements: PlanWithCorrespondingSort,
+) -> Result<Option<PlanWithCorrespondingSort>> {
+ // Perform naive analysis at the beginning -- remove already-satisfied
sorts:
+ if let Some(result) = analyze_immediate_sort_removal(&requirements)? {
+ return Ok(Some(result));
+ }
+ let plan = &requirements.plan;
+ let mut new_children = plan.children().clone();
+ let mut new_onwards = requirements.sort_onwards.clone();
+ for (idx, (child, sort_onwards, required_ordering)) in izip!(
+ new_children.iter_mut(),
+ new_onwards.iter_mut(),
+ plan.required_input_ordering()
+ )
+ .enumerate()
+ {
+ let physical_ordering = child.output_ordering();
+ match (required_ordering, physical_ordering) {
+ (Some(required_ordering), Some(physical_ordering)) => {
+ let is_ordering_satisfied = ordering_satisfy_concrete(
+ physical_ordering,
+ required_ordering,
+ || child.equivalence_properties(),
+ );
+ if !is_ordering_satisfied {
+ // Make sure we preserve the ordering requirements:
+ update_child_to_remove_unnecessary_sort(child,
sort_onwards)?;
+ let sort_expr = required_ordering.to_vec();
+ *child = add_sort_above_child(child, sort_expr)?;
+ sort_onwards.push((idx, child.clone()))
+ } else if let [first, ..] = sort_onwards.as_slice() {
+ // The ordering requirement is met, we can analyze if
there is an unnecessary sort:
+ let sort_any = first.1.clone();
+ let sort_exec = convert_to_sort_exec(&sort_any)?;
+ let sort_output_ordering = sort_exec.output_ordering();
+ let sort_input_ordering =
sort_exec.input().output_ordering();
+ // Simple analysis: Does the input of the sort in question
already satisfy the ordering requirements?
+ if ordering_satisfy(sort_input_ordering,
sort_output_ordering, || {
+ sort_exec.input().equivalence_properties()
+ }) {
+ update_child_to_remove_unnecessary_sort(child,
sort_onwards)?;
+ } else if let Some(window_agg_exec) =
+
requirements.plan.as_any().downcast_ref::<WindowAggExec>()
+ {
+ // For window expressions, we can remove some sorts
when we can
+ // calculate the result in reverse:
+ if let Some(res) = analyze_window_sort_removal(
+ window_agg_exec,
+ sort_exec,
+ sort_onwards,
+ )? {
+ return Ok(Some(res));
+ }
+ }
+ // TODO: Once we can ensure that required ordering
information propagates with
+ // necessary lineage information, compare
`sort_input_ordering` and `required_ordering`.
+ // This will enable us to handle cases such as (a,b)
-> Sort -> (a,b,c) -> Required(a,b).
+ // Currently, we can not remove such sorts.
+ }
+ }
+ (Some(required), None) => {
+ // Ordering requirement is not met, we should add a SortExec
to the plan.
+ let sort_expr = required.to_vec();
+ *child = add_sort_above_child(child, sort_expr)?;
+ *sort_onwards = vec![(idx, child.clone())];
+ }
+ (None, Some(_)) => {
+ // We have a SortExec whose effect may be neutralized by a
order-imposing
+ // operator. In this case, remove this sort:
+ if !requirements.plan.maintains_input_order() {
+ update_child_to_remove_unnecessary_sort(child,
sort_onwards)?;
+ }
+ }
+ (None, None) => {}
+ }
+ }
+ if plan.children().is_empty() {
+ Ok(Some(requirements))
+ } else {
+ let new_plan = requirements.plan.with_new_children(new_children)?;
+ for (idx, (trace, required_ordering)) in new_onwards
+ .iter_mut()
+ .zip(new_plan.required_input_ordering())
+ .enumerate()
+ .take(new_plan.children().len())
+ {
+ // TODO: When `maintains_input_order` returns a `Vec<bool>`, use
corresponding index.
+ if new_plan.maintains_input_order()
+ && required_ordering.is_none()
+ && !trace.is_empty()
+ {
+ trace.push((idx, new_plan.clone()));
+ } else {
+ trace.clear();
+ if new_plan.as_any().is::<SortExec>() {
+ trace.push((idx, new_plan.clone()));
+ }
+ }
+ }
+ Ok(Some(PlanWithCorrespondingSort {
+ plan: new_plan,
+ sort_onwards: new_onwards,
+ }))
+ }
+}
+
+/// Analyzes a given `SortExec` to determine whether its input already has
+/// a finer ordering than this `SortExec` enforces.
+fn analyze_immediate_sort_removal(
+ requirements: &PlanWithCorrespondingSort,
+) -> Result<Option<PlanWithCorrespondingSort>> {
+ if let Some(sort_exec) =
requirements.plan.as_any().downcast_ref::<SortExec>() {
+ // If this sort is unnecessary, we should remove it:
+ if ordering_satisfy(
+ sort_exec.input().output_ordering(),
+ sort_exec.output_ordering(),
+ || sort_exec.input().equivalence_properties(),
+ ) {
+ // Since we know that a `SortExec` has exactly one child,
+ // we can use the zero index safely:
+ let mut new_onwards = requirements.sort_onwards[0].to_vec();
+ if !new_onwards.is_empty() {
+ new_onwards.pop();
+ }
+ return Ok(Some(PlanWithCorrespondingSort {
+ plan: sort_exec.input().clone(),
+ sort_onwards: vec![new_onwards],
+ }));
+ }
+ }
+ Ok(None)
+}
+
+/// Analyzes a `WindowAggExec` to determine whether it may allow removing a
sort.
+fn analyze_window_sort_removal(
+ window_agg_exec: &WindowAggExec,
+ sort_exec: &SortExec,
+ sort_onward: &mut Vec<(usize, Arc<dyn ExecutionPlan>)>,
+) -> Result<Option<PlanWithCorrespondingSort>> {
+ let required_ordering = sort_exec.output_ordering().ok_or_else(|| {
+ DataFusionError::Plan("A SortExec should have output
ordering".to_string())
+ })?;
+ let physical_ordering = sort_exec.input().output_ordering();
+ let physical_ordering = if let Some(physical_ordering) = physical_ordering
{
+ physical_ordering
+ } else {
+ // If there is no physical ordering, there is no way to remove a sort
-- immediately return:
+ return Ok(None);
+ };
+ let window_expr = window_agg_exec.window_expr();
+ let (can_skip_sorting, should_reverse) = can_skip_sort(
+ window_expr[0].partition_by(),
+ required_ordering,
+ &sort_exec.input().schema(),
+ physical_ordering,
+ )?;
+ if can_skip_sorting {
+ let new_window_expr = if should_reverse {
+ window_expr
+ .iter()
+ .map(|e| e.get_reverse_expr())
+ .collect::<Option<Vec<_>>>()
+ } else {
+ Some(window_expr.to_vec())
+ };
+ if let Some(window_expr) = new_window_expr {
+ let new_child =
remove_corresponding_sort_from_sub_plan(sort_onward)?;
+ let new_schema = new_child.schema();
+ let new_plan = Arc::new(WindowAggExec::try_new(
+ window_expr,
+ new_child,
+ new_schema,
+ window_agg_exec.partition_keys.clone(),
+ Some(physical_ordering.to_vec()),
+ )?);
+ return Ok(Some(PlanWithCorrespondingSort::new(new_plan)));
+ }
+ }
+ Ok(None)
+}
+
+/// Updates child to remove the unnecessary sorting below it.
+fn update_child_to_remove_unnecessary_sort(
+ child: &mut Arc<dyn ExecutionPlan>,
+ sort_onwards: &mut Vec<(usize, Arc<dyn ExecutionPlan>)>,
+) -> Result<()> {
+ if !sort_onwards.is_empty() {
+ *child = remove_corresponding_sort_from_sub_plan(sort_onwards)?;
+ }
+ Ok(())
+}
+
+/// Converts an [ExecutionPlan] trait object to a [SortExec] when possible.
+fn convert_to_sort_exec(sort_any: &Arc<dyn ExecutionPlan>) ->
Result<&SortExec> {
+ sort_any.as_any().downcast_ref::<SortExec>().ok_or_else(|| {
+ DataFusionError::Plan("Given ExecutionPlan is not a
SortExec".to_string())
+ })
+}
+
+/// Removes the sort from the plan in `sort_onwards`.
+fn remove_corresponding_sort_from_sub_plan(
+ sort_onwards: &mut Vec<(usize, Arc<dyn ExecutionPlan>)>,
+) -> Result<Arc<dyn ExecutionPlan>> {
+ let (sort_child_idx, sort_any) = sort_onwards[0].clone();
+ let sort_exec = convert_to_sort_exec(&sort_any)?;
+ let mut prev_layer = sort_exec.input().clone();
+ let mut prev_child_idx = sort_child_idx;
+ // In the loop below, se start from 1 as the first one is a SortExec
+ // and we are removing it from the plan.
+ for (child_idx, layer) in sort_onwards.iter().skip(1) {
+ let mut children = layer.children();
+ children[prev_child_idx] = prev_layer;
+ prev_layer = layer.clone().with_new_children(children)?;
+ prev_child_idx = *child_idx;
+ }
+ // We have removed the sort, hence empty the sort_onwards:
+ sort_onwards.clear();
+ Ok(prev_layer)
+}
+
/// Per-column facts gathered by [`can_skip_sort`] when deciding whether an
/// existing physical ordering makes a required sort unnecessary.
#[derive(Debug)]
pub struct ColumnInfo {
    /// First element returned by `check_alignment`: whether the physical
    /// ordering of this column is compatible with the required one.
    is_aligned: bool,
    /// Second element returned by `check_alignment`: whether the physical
    /// ordering is the reverse of the required one.
    reverse: bool,
    /// Whether this column's expression is among the partition keys passed
    /// to `can_skip_sort`.
    is_partition: bool,
}
+
+/// Compares physical ordering and required ordering of all
`PhysicalSortExpr`s and returns a tuple.
+/// The first element indicates whether these `PhysicalSortExpr`s can be
removed from the physical plan.
+/// The second element is a flag indicating whether we should reverse the sort
direction in order to
+/// remove physical sort expressions from the plan.
+pub fn can_skip_sort(
+ partition_keys: &[Arc<dyn PhysicalExpr>],
+ required: &[PhysicalSortExpr],
+ input_schema: &SchemaRef,
+ physical_ordering: &[PhysicalSortExpr],
+) -> Result<(bool, bool)> {
+ if required.len() > physical_ordering.len() {
+ return Ok((false, false));
+ }
+ let mut col_infos = vec![];
+ for (sort_expr, physical_expr) in zip(required, physical_ordering) {
+ let column = sort_expr.expr.clone();
+ let is_partition = partition_keys.iter().any(|e| e.eq(&column));
+ let (is_aligned, reverse) =
+ check_alignment(input_schema, physical_expr, sort_expr);
+ col_infos.push(ColumnInfo {
+ is_aligned,
+ reverse,
+ is_partition,
+ });
+ }
+ let partition_by_sections = col_infos
+ .iter()
+ .filter(|elem| elem.is_partition)
+ .collect::<Vec<_>>();
+ let can_skip_partition_bys = if partition_by_sections.is_empty() {
+ true
+ } else {
+ let first_reverse = partition_by_sections[0].reverse;
+ let can_skip_partition_bys = partition_by_sections
+ .iter()
+ .all(|c| c.is_aligned && c.reverse == first_reverse);
+ can_skip_partition_bys
+ };
+ let order_by_sections = col_infos
+ .iter()
+ .filter(|elem| !elem.is_partition)
+ .collect::<Vec<_>>();
+ let (can_skip_order_bys, should_reverse_order_bys) = if
order_by_sections.is_empty() {
+ (true, false)
+ } else {
+ let first_reverse = order_by_sections[0].reverse;
+ let can_skip_order_bys = order_by_sections
+ .iter()
+ .all(|c| c.is_aligned && c.reverse == first_reverse);
+ (can_skip_order_bys, first_reverse)
+ };
+ let can_skip = can_skip_order_bys && can_skip_partition_bys;
+ Ok((can_skip, should_reverse_order_bys))
+}
+
+/// Compares `physical_ordering` and `required` ordering, returns a tuple
+/// indicating (1) whether this column requires sorting, and (2) whether we
+/// should reverse the window expression in order to avoid sorting.
+fn check_alignment(
+ input_schema: &SchemaRef,
+ physical_ordering: &PhysicalSortExpr,
+ required: &PhysicalSortExpr,
+) -> (bool, bool) {
+ if required.expr.eq(&physical_ordering.expr) {
+ let nullable = required.expr.nullable(input_schema).unwrap();
+ let physical_opts = physical_ordering.options;
+ let required_opts = required.options;
+ let is_reversed = if nullable {
+ physical_opts == reverse_sort_options(required_opts)
+ } else {
+ // If the column is not nullable, NULLS FIRST/LAST is not
important.
+ physical_opts.descending != required_opts.descending
+ };
+ let can_skip = !nullable || is_reversed || (physical_opts ==
required_opts);
+ (can_skip, is_reversed)
+ } else {
+ (false, false)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::physical_plan::displayable;
+ use crate::physical_plan::filter::FilterExec;
+ use crate::physical_plan::memory::MemoryExec;
+ use
crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
+ use crate::physical_plan::windows::create_window_expr;
+ use crate::prelude::SessionContext;
+ use arrow::compute::SortOptions;
+ use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+ use datafusion_common::Result;
+ use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction};
+ use datafusion_physical_expr::expressions::{col, NotExpr};
+ use datafusion_physical_expr::PhysicalSortExpr;
+ use std::sync::Arc;
+
+ fn create_test_schema() -> Result<SchemaRef> {
+ let nullable_column = Field::new("nullable_col", DataType::Int32,
true);
+ let non_nullable_column = Field::new("non_nullable_col",
DataType::Int32, false);
+ let schema = Arc::new(Schema::new(vec![nullable_column,
non_nullable_column]));
+
+ Ok(schema)
+ }
+
+ #[tokio::test]
+ async fn test_is_column_aligned_nullable() -> Result<()> {
+ let schema = create_test_schema()?;
+ let params = vec![
+ ((true, true), (false, false), (true, true)),
+ ((true, true), (false, true), (false, false)),
+ ((true, true), (true, false), (false, false)),
+ ((true, false), (false, true), (true, true)),
+ ((true, false), (false, false), (false, false)),
+ ((true, false), (true, true), (false, false)),
+ ];
+ for (
+ (physical_desc, physical_nulls_first),
+ (req_desc, req_nulls_first),
+ (is_aligned_expected, reverse_expected),
+ ) in params
+ {
+ let physical_ordering = PhysicalSortExpr {
+ expr: col("nullable_col", &schema)?,
+ options: SortOptions {
+ descending: physical_desc,
+ nulls_first: physical_nulls_first,
+ },
+ };
+ let required_ordering = PhysicalSortExpr {
+ expr: col("nullable_col", &schema)?,
+ options: SortOptions {
+ descending: req_desc,
+ nulls_first: req_nulls_first,
+ },
+ };
+ let (is_aligned, reverse) =
+ check_alignment(&schema, &physical_ordering,
&required_ordering);
+ assert_eq!(is_aligned, is_aligned_expected);
+ assert_eq!(reverse, reverse_expected);
+ }
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_is_column_aligned_non_nullable() -> Result<()> {
+ let schema = create_test_schema()?;
+
+ let params = vec![
+ ((true, true), (false, false), (true, true)),
+ ((true, true), (false, true), (true, true)),
+ ((true, true), (true, false), (true, false)),
+ ((true, false), (false, true), (true, true)),
+ ((true, false), (false, false), (true, true)),
+ ((true, false), (true, true), (true, false)),
+ ];
+ for (
+ (physical_desc, physical_nulls_first),
+ (req_desc, req_nulls_first),
+ (is_aligned_expected, reverse_expected),
+ ) in params
+ {
+ let physical_ordering = PhysicalSortExpr {
+ expr: col("non_nullable_col", &schema)?,
+ options: SortOptions {
+ descending: physical_desc,
+ nulls_first: physical_nulls_first,
+ },
+ };
+ let required_ordering = PhysicalSortExpr {
+ expr: col("non_nullable_col", &schema)?,
+ options: SortOptions {
+ descending: req_desc,
+ nulls_first: req_nulls_first,
+ },
+ };
+ let (is_aligned, reverse) =
+ check_alignment(&schema, &physical_ordering,
&required_ordering);
+ assert_eq!(is_aligned, is_aligned_expected);
+ assert_eq!(reverse, reverse_expected);
+ }
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_remove_unnecessary_sort() -> Result<()> {
+ let session_ctx = SessionContext::new();
+ let conf = session_ctx.copied_config();
+ let schema = create_test_schema()?;
+ let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?)
+ as Arc<dyn ExecutionPlan>;
+ let sort_exprs = vec![PhysicalSortExpr {
+ expr: col("non_nullable_col", schema.as_ref()).unwrap(),
+ options: SortOptions::default(),
+ }];
+ let sort_exec = Arc::new(SortExec::try_new(sort_exprs, source, None)?)
+ as Arc<dyn ExecutionPlan>;
+ let sort_exprs = vec![PhysicalSortExpr {
+ expr: col("nullable_col", schema.as_ref()).unwrap(),
+ options: SortOptions::default(),
+ }];
+ let physical_plan = Arc::new(SortExec::try_new(sort_exprs, sort_exec,
None)?)
+ as Arc<dyn ExecutionPlan>;
+ let formatted =
displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = {
+ vec![
+ "SortExec: [nullable_col@0 ASC]",
+ " SortExec: [non_nullable_col@1 ASC]",
+ ]
+ };
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ let optimized_physical_plan =
+ RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?;
+ let formatted = displayable(optimized_physical_plan.as_ref())
+ .indent()
+ .to_string();
+ let expected = { vec!["SortExec: [nullable_col@0 ASC]"] };
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> {
+ let session_ctx = SessionContext::new();
+ let conf = session_ctx.copied_config();
+ let schema = create_test_schema()?;
+ let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?)
+ as Arc<dyn ExecutionPlan>;
+ let sort_exprs = vec![PhysicalSortExpr {
+ expr: col("non_nullable_col", source.schema().as_ref()).unwrap(),
+ options: SortOptions {
+ descending: true,
+ nulls_first: true,
+ },
+ }];
+ let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source,
None)?)
+ as Arc<dyn ExecutionPlan>;
+ let window_agg_exec = Arc::new(WindowAggExec::try_new(
+ vec![create_window_expr(
+ &WindowFunction::AggregateFunction(AggregateFunction::Count),
+ "count".to_owned(),
+ &[col("non_nullable_col", &schema)?],
+ &[],
+ &sort_exprs,
+ Arc::new(WindowFrame::new(true)),
+ schema.as_ref(),
+ )?],
+ sort_exec.clone(),
+ sort_exec.schema(),
+ vec![],
+ Some(sort_exprs),
+ )?) as Arc<dyn ExecutionPlan>;
+ let sort_exprs = vec![PhysicalSortExpr {
+ expr: col("non_nullable_col",
window_agg_exec.schema().as_ref()).unwrap(),
+ options: SortOptions {
+ descending: false,
+ nulls_first: false,
+ },
+ }];
+ let sort_exec = Arc::new(SortExec::try_new(
+ sort_exprs.clone(),
+ window_agg_exec,
+ None,
+ )?) as Arc<dyn ExecutionPlan>;
+ // Add dummy layer propagating Sort above, to test whether sort can be
removed from multi layer before
+ let filter_exec = Arc::new(FilterExec::try_new(
+ Arc::new(NotExpr::new(
+ col("non_nullable_col", schema.as_ref()).unwrap(),
+ )),
+ sort_exec,
+ )?) as Arc<dyn ExecutionPlan>;
+ // let filter_exec = sort_exec;
+ let window_agg_exec = Arc::new(WindowAggExec::try_new(
+ vec![create_window_expr(
+ &WindowFunction::AggregateFunction(AggregateFunction::Count),
+ "count".to_owned(),
+ &[col("non_nullable_col", &schema)?],
+ &[],
+ &sort_exprs,
+ Arc::new(WindowFrame::new(true)),
+ schema.as_ref(),
+ )?],
+ filter_exec.clone(),
+ filter_exec.schema(),
+ vec![],
+ Some(sort_exprs),
+ )?) as Arc<dyn ExecutionPlan>;
+ let physical_plan = window_agg_exec;
+ let formatted =
displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = {
+ vec![
+ "WindowAggExec: wdw=[count: Ok(Field { name: \"count\",
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata:
{} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL),
end_bound: CurrentRow }]",
+ " FilterExec: NOT non_nullable_col@1",
+ " SortExec: [non_nullable_col@2 ASC NULLS LAST]",
+ " WindowAggExec: wdw=[count: Ok(Field { name: \"count\",
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata:
{} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL),
end_bound: CurrentRow }]",
+ " SortExec: [non_nullable_col@1 DESC]",
+ " MemoryExec: partitions=0, partition_sizes=[]",
+ ]
+ };
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ let optimized_physical_plan =
+ RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?;
+ let formatted = displayable(optimized_physical_plan.as_ref())
+ .indent()
+ .to_string();
+ let expected = {
+ vec![
+ "WindowAggExec: wdw=[count: Ok(Field { name: \"count\",
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata:
{} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound:
Following(NULL) }]",
+ " FilterExec: NOT non_nullable_col@1",
+ " WindowAggExec: wdw=[count: Ok(Field { name: \"count\",
data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata:
{} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL),
end_bound: CurrentRow }]",
+ " SortExec: [non_nullable_col@1 DESC]",
+ " MemoryExec: partitions=0, partition_sizes=[]",
+ ]
+ };
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_add_required_sort() -> Result<()> {
Review Comment:
I think enforcement is the pass that is supposed to be adding required sorts
-- as above I find it confusing that "remove unnecessary sorts" is also adding
sorts 🤔
##########
datafusion/core/src/execution/context.rs:
##########
@@ -1603,6 +1604,12 @@ impl SessionState {
// To make sure the SinglePartition is satisfied, run the
BasicEnforcement again, originally it was the AddCoalescePartitionsExec here.
physical_optimizers.push(Arc::new(BasicEnforcement::new()));
+ // `BasicEnforcement` stage conservatively inserts `SortExec`s to
satisfy ordering requirements.
Review Comment:
👍
##########
datafusion/physical-expr/src/window/lead_lag.rs:
##########
@@ -107,6 +106,16 @@ impl BuiltInWindowFunctionExpr for WindowShift {
default_value: self.default_value.clone(),
}))
}
+
+ fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
+ Some(Arc::new(Self {
+ name: self.name.clone(),
+ data_type: self.data_type.clone(),
+ shift_offset: -self.shift_offset,
Review Comment:
this is nice 👌
##########
datafusion/expr/src/window_frame.rs:
##########
@@ -113,6 +113,33 @@ impl WindowFrame {
}
}
}
+
+ /// Get reversed window frame
Review Comment:
Some examples would probably help future readers here understand what is
meant by a REVERSED frame
Something like this, perhaps:
```suggestion
/// Get reversed window frame. For example
/// `3 ROWS PRECEDING AND 2 ROWS FOLLOWING` -->
/// `2 ROWS PRECEDING AND 3 ROWS FOLLOWING`
```
##########
datafusion/expr/src/utils.rs:
##########
@@ -204,7 +204,7 @@ pub fn expand_qualified_wildcard(
expand_wildcard(&qualifier_schema, plan)
}
-type WindowSortKey = Vec<Expr>;
+type WindowSortKey = Vec<(Expr, bool)>;
Review Comment:
Can we please document in docstrings what this bool represents?
Is it
```suggestion
/// (expr, "is the expr part of the window partition key")
type WindowSortKey = Vec<(Expr, bool)>;
```
##########
datafusion/core/src/physical_plan/windows/window_agg_exec.rs:
##########
@@ -131,6 +135,27 @@ impl WindowAggExec {
pub fn input_schema(&self) -> SchemaRef {
self.input_schema.clone()
}
+
+ /// Get partition keys
+ pub fn partition_by_sort_keys(&self) -> Result<Vec<PhysicalSortExpr>> {
+ let mut result = vec![];
+ // All window exprs have the same partition by, so we just use the
first one:
+ let partition_by = self.window_expr()[0].partition_by();
+ let sort_keys = self
+ .sort_keys
+ .as_ref()
+ .map_or_else(|| &[] as &[PhysicalSortExpr], |v| v.as_slice());
Review Comment:
```suggestion
let sort_keys = self
.sort_keys
.as_ref()
.map(|v| v.as_slice())
.unwrap_or(&[]);
```
Not sure if you think that is any better 🤔
##########
datafusion/physical-expr/src/aggregate/mod.rs:
##########
@@ -101,4 +101,12 @@ pub trait AggregateExpr: Send + Sync + Debug {
self
)))
}
+
+ /// Construct an expression that calculates the aggregate in reverse.
Review Comment:
```suggestion
/// Construct an expression that calculates the aggregate in reverse.
/// Typically the "reverse" expression is itself (e.g. SUM, COUNT).
/// For aggregates that do not support calculation in reverse,
/// returns None (which is the default value).
```
##########
datafusion/core/src/physical_plan/repartition.rs:
##########
@@ -289,7 +289,21 @@ impl ExecutionPlan for RepartitionExec {
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
- None
+ if self.maintains_input_order() {
+ self.input().output_ordering()
+ } else {
+ None
+ }
+ }
+
+ fn maintains_input_order(&self) -> bool {
+ // We preserve ordering when input partitioning is 1
+ let n_input = match self.input().output_partitioning() {
+ Partitioning::RoundRobinBatch(n) => n,
+ Partitioning::Hash(_, n) => n,
+ Partitioning::UnknownPartitioning(n) => n,
+ };
+ n_input <= 1
Review Comment:
```suggestion
self.input().output_partitioning().partition_count() <= 1
```
##########
datafusion/physical-expr/src/window/built_in.rs:
##########
@@ -91,50 +91,51 @@ impl WindowExpr for BuiltInWindowExpr {
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let evaluator = self.expr.create_evaluator()?;
let num_rows = batch.num_rows();
- let partition_columns = self.partition_columns(batch)?;
- let partition_points =
- self.evaluate_partition_points(num_rows, &partition_columns)?;
-
- let results = if evaluator.uses_window_frame() {
+ if evaluator.uses_window_frame() {
let sort_options: Vec<SortOptions> =
self.order_by.iter().map(|o| o.options).collect();
let mut row_wise_results = vec![];
- for partition_range in &partition_points {
- let length = partition_range.end - partition_range.start;
- let (values, order_bys) = self
- .get_values_orderbys(&batch.slice(partition_range.start,
length))?;
- let mut window_frame_ctx =
WindowFrameContext::new(&self.window_frame);
- // We iterate on each row to calculate window frame range and
and window function result
- for idx in 0..length {
- let range = window_frame_ctx.calculate_range(
- &order_bys,
- &sort_options,
- num_rows,
- idx,
- )?;
- let range = Range {
- start: range.0,
- end: range.1,
- };
- let value = evaluator.evaluate_inside_range(&values,
range)?;
- row_wise_results.push(value.to_array());
- }
+
+ let length = batch.num_rows();
+ let (values, order_bys) = self.get_values_orderbys(batch)?;
+ let mut window_frame_ctx =
WindowFrameContext::new(&self.window_frame);
+ // We iterate on each row to calculate window frame range and and
window function result
+ for idx in 0..length {
+ let range = window_frame_ctx.calculate_range(
+ &order_bys,
+ &sort_options,
+ num_rows,
+ idx,
+ )?;
+ let value = evaluator.evaluate_inside_range(&values, range)?;
+ row_wise_results.push(value);
}
- row_wise_results
+ ScalarValue::iter_to_array(row_wise_results.into_iter())
} else if evaluator.include_rank() {
let columns = self.sort_columns(batch)?;
let sort_partition_points =
self.evaluate_partition_points(num_rows, &columns)?;
- evaluator.evaluate_with_rank(partition_points,
sort_partition_points)?
+ evaluator.evaluate_with_rank(num_rows, &sort_partition_points)
} else {
let (values, _) = self.get_values_orderbys(batch)?;
- evaluator.evaluate(&values, partition_points)?
- };
- let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
- concat(&results).map_err(DataFusionError::ArrowError)
+ evaluator.evaluate(&values, num_rows)
+ }
}
fn get_window_frame(&self) -> &Arc<WindowFrame> {
&self.window_frame
}
+
+ fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
+ if let Some(reverse_expr) = self.expr.reverse_expr() {
+ Some(Arc::new(BuiltInWindowExpr::new(
+ reverse_expr,
+ &self.partition_by.clone(),
+ &reverse_order_bys(&self.order_by),
+ Arc::new(self.window_frame.reverse()),
+ )))
+ } else {
+ None
+ }
Review Comment:
You could also express this in a more functional style if you wanted:
```suggestion
self.expr.reverse_expr()
.map(|reverse_expr| {
Arc::new(BuiltInWindowExpr::new(
reverse_expr,
&self.partition_by.clone(),
&reverse_order_bys(&self.order_by),
Arc::new(self.window_frame.reverse()),
)) as _
})
}
```
##########
datafusion/core/src/physical_plan/windows/window_agg_exec.rs:
##########
@@ -368,16 +397,58 @@ impl WindowAggStream {
let batch = concat_batches(&self.input.schema(), &self.batches)?;
- // calculate window cols
- let mut columns = compute_window_aggregates(&self.window_expr, &batch)
- .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
+ let partition_by_sort_keys = self
+ .partition_by_sort_keys
+ .iter()
+ .map(|elem| elem.evaluate_to_sort_column(&batch))
+ .collect::<Result<Vec<_>>>()?;
+ let partition_points =
+ self.evaluate_partition_points(batch.num_rows(),
&partition_by_sort_keys)?;
+
+ let mut partition_results = vec![];
+ // Calculate window cols
+ for partition_point in partition_points {
+ let length = partition_point.end - partition_point.start;
+ partition_results.push(
+ compute_window_aggregates(
+ &self.window_expr,
+ &batch.slice(partition_point.start, length),
+ )
+ .map_err(|e| ArrowError::ExternalError(Box::new(e)))?,
+ )
+ }
+ let mut columns = transpose(partition_results)
Review Comment:
Am I correct in my understanding that this code is effectively creating new
`ArrayRef`s for each partition (rather than passing in `(usize, usize)`
boundaries for each partition)?
##########
datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs:
##########
@@ -0,0 +1,887 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! RemoveUnnecessarySorts optimizer rule inspects SortExec's in the given
+//! physical plan and removes the ones it can prove unnecessary. The rule can
+//! work on valid *and* invalid physical plans with respect to sorting
+//! requirements, but always produces a valid physical plan in this sense.
+//!
+//! A non-realistic but easy to follow example: Assume that we somehow get the
fragment
+//! "SortExec: [nullable_col@0 ASC]",
+//! " SortExec: [non_nullable_col@1 ASC]",
+//! in the physical plan. The second (inner) sort is unnecessary since its
+//! result is overwritten by the first (outer) SortExec. Therefore, this rule
+//! removes the inner sort from the physical plan.
+use crate::error::Result;
+use crate::physical_optimizer::utils::{
+ add_sort_above_child, ordering_satisfy, ordering_satisfy_concrete,
+};
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::rewrite::TreeNodeRewritable;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::windows::WindowAggExec;
+use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan};
+use crate::prelude::SessionConfig;
+use arrow::datatypes::SchemaRef;
+use datafusion_common::{reverse_sort_options, DataFusionError};
+use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+use itertools::izip;
+use std::iter::zip;
+use std::sync::Arc;
+
/// Physical optimizer rule that inspects the `SortExec`s in a physical plan and
/// removes the ones it can prove are unnecessary.
#[derive(Default)]
pub struct RemoveUnnecessarySorts {}

impl RemoveUnnecessarySorts {
    /// Creates a new instance of the rule. The rule carries no configuration.
    pub fn new() -> Self {
        Self::default()
    }
}
+
+/// This is a "data class" we use within the [RemoveUnnecessarySorts] rule
+/// that tracks the closest `SortExec` descendant for every child of a plan.
Review Comment:
I wonder if we could use `ExecutionPlan::output_ordering()` rather than
tracking the state separately?
That way the logic of which nodes retain what ordering is kept within the
execution plan nodes themselves rather than in this pass.
That said, it looks strange to me to have both
`ExecutionPlan::maintains_input_order` and `ExecutionPlan::output_ordering`
(I think `maintains_input_order` came first) -- maybe we should remove
`ExecutionPlan::maintains_input_order` 🤔 -- I could try this if you think it is
worthwhile.
##########
datafusion/core/tests/sql/window.rs:
##########
@@ -1748,17 +1748,20 @@ async fn test_window_partition_by_order_by() ->
Result<()> {
let msg = format!("Creating logical plan for '{}'", sql);
let dataframe = ctx.sql(sql).await.expect(&msg);
- let physical_plan = dataframe.create_physical_plan().await.unwrap();
+ let physical_plan = dataframe.create_physical_plan().await?;
let formatted = displayable(physical_plan.as_ref()).indent().to_string();
- // Only 1 SortExec was added
let expected = {
vec![
- "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY
[aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2
ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as
SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY
[aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS
BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as COUNT(UInt8(1))]",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field {
name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true,
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1))
}, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64,
nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame:
WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound:
Following(UInt64(1)) }]",
- " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]",
+ "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY
[aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2
ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as
SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY
[aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS
BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as COUNT(UInt8(1))]",
+ " WindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name:
\"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]",
+ " SortExec: [c1@1 ASC NULLS LAST,c2@2 ASC NULLS LAST]",
" CoalesceBatchesExec: target_batch_size=4096",
- " RepartitionExec: partitioning=Hash([Column { name:
\"c1\", index: 0 }], 2)",
- " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " RepartitionExec: partitioning=Hash([Column { name:
\"c1\", index: 1 }], 2)",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c4):
Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable:
true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame {
units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1))
}]",
+ " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]",
Review Comment:
This plan seems worse than the original, doesn't it? It now sorts the data
twice (both times on `c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST`).
Maybe I am missing something 🤔
##########
datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs:
##########
@@ -0,0 +1,887 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! RemoveUnnecessarySorts optimizer rule inspects `SortExec`s in the given
+//! physical plan and removes the ones it can prove unnecessary. The rule can
+//! work on valid *and* invalid physical plans with respect to sorting
+//! requirements, but always produces a valid physical plan in this sense,
+//! inserting a `SortExec` wherever a required input ordering is not satisfied.
+//!
+//! A non-realistic but easy-to-follow example: Assume that we somehow get the fragment
+//! "SortExec: [nullable_col@0 ASC]",
+//! " SortExec: [non_nullable_col@1 ASC]",
+//! in the physical plan. The inner sort (on `non_nullable_col`) is unnecessary
+//! since its result is overwritten by the outer `SortExec`. Therefore, this rule
+//! removes it from the physical plan.
+use crate::error::Result;
+use crate::physical_optimizer::utils::{
+ add_sort_above_child, ordering_satisfy, ordering_satisfy_concrete,
+};
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::rewrite::TreeNodeRewritable;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::windows::WindowAggExec;
+use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan};
+use crate::prelude::SessionConfig;
+use arrow::datatypes::SchemaRef;
+use datafusion_common::{reverse_sort_options, DataFusionError};
+use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+use itertools::izip;
+use std::iter::zip;
+use std::sync::Arc;
+
+/// This rule inspects `SortExec`s in the given physical plan and removes the
+/// ones it can prove unnecessary.
+#[derive(Default)]
+pub struct RemoveUnnecessarySorts {}
+
+impl RemoveUnnecessarySorts {
+ #[allow(missing_docs)]
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+/// This is a "data class" we use within the [RemoveUnnecessarySorts] rule
+/// that tracks the closest `SortExec` descendant for every child of a plan.
+#[derive(Debug, Clone)]
+struct PlanWithCorrespondingSort {
+ plan: Arc<dyn ExecutionPlan>,
+ // For every child, keep a vector of `ExecutionPlan`s starting from the
+ // closest `SortExec` till the current plan. The first index of the tuple is
+ // the child index of the plan -- we need this information as we make updates.
+ sort_onwards: Vec<Vec<(usize, Arc<dyn ExecutionPlan>)>>,
+}
+
+impl PlanWithCorrespondingSort {
+ pub fn new(plan: Arc<dyn ExecutionPlan>) -> Self {
+ let length = plan.children().len();
+ PlanWithCorrespondingSort {
+ plan,
+ sort_onwards: vec![vec![]; length],
+ }
+ }
+
+ pub fn children(&self) -> Vec<PlanWithCorrespondingSort> {
+ self.plan
+ .children()
+ .into_iter()
+ .map(|child| PlanWithCorrespondingSort::new(child))
+ .collect()
+ }
+}
+
+impl TreeNodeRewritable for PlanWithCorrespondingSort {
+ fn map_children<F>(self, transform: F) -> Result<Self>
+ where
+ F: FnMut(Self) -> Result<Self>,
+ {
+ let children = self.children();
+ if children.is_empty() {
+ Ok(self)
+ } else {
+ let children_requirements = children
+ .into_iter()
+ .map(transform)
+ .collect::<Result<Vec<_>>>()?;
+ let children_plans = children_requirements
+ .iter()
+ .map(|elem| elem.plan.clone())
+ .collect::<Vec<_>>();
+ let sort_onwards = children_requirements
+ .iter()
+ .map(|item| {
+ if item.sort_onwards.is_empty() {
+ vec![]
+ } else {
+ // TODO: When `maintains_input_order` returns `Vec<bool>`,
+ // pass the order-enforcing sort upwards.
+ item.sort_onwards[0].clone()
+ }
+ })
+ .collect::<Vec<_>>();
+ let plan = with_new_children_if_necessary(self.plan,
children_plans)?;
+ Ok(PlanWithCorrespondingSort { plan, sort_onwards })
+ }
+ }
+}
+
+impl PhysicalOptimizerRule for RemoveUnnecessarySorts {
+ fn optimize(
+ &self,
+ plan: Arc<dyn ExecutionPlan>,
+ _config: &SessionConfig,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ // Execute a post-order traversal to adjust input key ordering:
+ let plan_requirements = PlanWithCorrespondingSort::new(plan);
+ let adjusted =
plan_requirements.transform_up(&remove_unnecessary_sorts)?;
+ Ok(adjusted.plan)
+ }
+
+ fn name(&self) -> &str {
+ "RemoveUnnecessarySorts"
+ }
+
+ fn schema_check(&self) -> bool {
+ true
+ }
+}
+
+fn remove_unnecessary_sorts(
+ requirements: PlanWithCorrespondingSort,
+) -> Result<Option<PlanWithCorrespondingSort>> {
+ // Perform naive analysis at the beginning -- remove already-satisfied sorts:
+ if let Some(result) = analyze_immediate_sort_removal(&requirements)? {
+ return Ok(Some(result));
+ }
+ let plan = &requirements.plan;
+ let mut new_children = plan.children().clone();
+ let mut new_onwards = requirements.sort_onwards.clone();
+ for (idx, (child, sort_onwards, required_ordering)) in izip!(
+ new_children.iter_mut(),
+ new_onwards.iter_mut(),
+ plan.required_input_ordering()
+ )
+ .enumerate()
+ {
+ let physical_ordering = child.output_ordering();
+ match (required_ordering, physical_ordering) {
+ (Some(required_ordering), Some(physical_ordering)) => {
+ let is_ordering_satisfied = ordering_satisfy_concrete(
+ physical_ordering,
+ required_ordering,
+ || child.equivalence_properties(),
+ );
+ if !is_ordering_satisfied {
+ // Make sure we preserve the ordering requirements:
+ update_child_to_remove_unnecessary_sort(child,
sort_onwards)?;
+ let sort_expr = required_ordering.to_vec();
+ *child = add_sort_above_child(child, sort_expr)?;
Review Comment:
It is somewhat strange to me that a pass called "RemoveUnnecessarySorts" is
actually sometimes adding sorts 🤔
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]