kumarUjjawal commented on code in PR #21621: URL: https://github.com/apache/datafusion/pull/21621#discussion_r3246261643
########## datafusion/optimizer/src/push_down_topk_through_join.rs: ########## @@ -0,0 +1,1138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`PushDownTopKThroughJoin`] pushes TopK (Sort with fetch) through outer joins +//! +//! When a `Sort` with a fetch limit sits above an outer join and all sort +//! expressions come from the **preserved** side, this rule inserts a copy +//! of the `Sort(fetch)` on that input to reduce the number of rows +//! entering the join. +//! +//! This is correct because: +//! - A LEFT JOIN preserves every left row (each appears at least once in the +//! output). The final top-N by left-side columns must come from the top-N +//! left rows. +//! - The same reasoning applies symmetrically for RIGHT JOIN and right-side +//! columns. +//! +//! The top-level sort is kept for correctness since a 1-to-many join can +//! produce more than N output rows from N input rows. +//! +//! ## Example +//! +//! Before: +//! ```text +//! Sort: t1.b ASC, fetch=3 +//! Left Join: t1.a = t2.c +//! Scan: t1 ← scans ALL rows +//! Scan: t2 +//! ``` +//! +//! After: +//! ```text +//! Sort: t1.b ASC, fetch=3 +//! Left Join: t1.a = t2.c +//! Sort: t1.b ASC, fetch=3 ← pushed down +//! Scan: t1 +//! 
Scan: t2 +//! ``` + +use std::sync::Arc; + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use crate::utils::{has_all_column_refs, schema_columns}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{Column, Result}; +use datafusion_expr::logical_plan::{ + JoinType, LogicalPlan, Projection, Sort as SortPlan, SubqueryAlias, +}; +use datafusion_expr::{Expr, SortExpr}; + +/// Optimization rule that pushes TopK (Sort with fetch) through +/// LEFT / RIGHT outer joins when all sort expressions come from +/// the preserved side. +/// +/// See module-level documentation for details. +#[derive(Default, Debug)] +pub struct PushDownTopKThroughJoin; + +impl PushDownTopKThroughJoin { + #[expect(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for PushDownTopKThroughJoin { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result<Transformed<LogicalPlan>> { + // Match Sort with fetch (TopK) + let LogicalPlan::Sort(sort) = &plan else { + return Ok(Transformed::no(plan)); + }; + let Some(fetch) = sort.fetch else { + return Ok(Transformed::no(plan)); + }; + + // Don't push if any sort expression is non-deterministic (e.g. random()). + // Duplicating such expressions would produce different values at each + // evaluation point, potentially changing the result. + if sort.expr.iter().any(|se| se.expr.is_volatile()) { + return Ok(Transformed::no(plan)); + } + + // Peel through transparent nodes (SubqueryAlias, Projection) to find + // the Join. Track intermediate nodes so we can reconstruct the tree + // and resolve sort expressions through them. 
+ let mut current = sort.input.as_ref(); + let mut intermediates: Vec<&LogicalPlan> = Vec::new(); + let join = loop { + match current { + LogicalPlan::Join(join) => break join, + LogicalPlan::Projection(proj) => { + intermediates.push(current); + current = proj.input.as_ref(); + } + LogicalPlan::SubqueryAlias(sq) => { + intermediates.push(current); + current = sq.input.as_ref(); + } + _ => return Ok(Transformed::no(plan)), + } + }; + + // Only outer joins where the preserved side is known. + // Semi/Anti joins are excluded: not all preserved-side rows appear in + // the output (only matched/unmatched rows do), so pushing fetch=N to + // the preserved child can drop rows that would have survived the filter. + // + // Non-equijoin filters in the ON clause are safe: outer joins guarantee + // all preserved-side rows appear in the output regardless of the filter. + // The filter only controls matching (which non-preserved rows pair up), + // not which preserved rows survive. + let preserved_is_left = match join.join_type { + JoinType::Left => true, + JoinType::Right => false, + _ => return Ok(Transformed::no(plan)), + }; + + // Resolve sort expressions through all intermediate nodes (Projection, + // SubqueryAlias) so that column references match the join's schema. + let mut resolved_sort_exprs = sort.expr.clone(); + for node in &intermediates { + match node { + LogicalPlan::Projection(proj) => { + resolved_sort_exprs = resolve_sort_exprs_through_projection( + &resolved_sort_exprs, + proj, + )?; + } + LogicalPlan::SubqueryAlias(sq) => { + resolved_sort_exprs = resolve_sort_exprs_through_subquery_alias( + &resolved_sort_exprs, + sq, + )?; + } + _ => unreachable!(), + } + } + + // After resolving through projections, the sort expressions may now + // contain volatile functions (e.g. `random() AS col`). Duplicating + // volatile expressions in the pushed Sort would produce different + // values, changing results. 
+ if resolved_sort_exprs.iter().any(|se| se.expr.is_volatile()) { + return Ok(Transformed::no(plan)); + } + + let preserved_schema = if preserved_is_left { + join.left.schema() + } else { + join.right.schema() + }; + let preserved_cols = schema_columns(preserved_schema); + + let all_from_preserved = resolved_sort_exprs + .iter() + .all(|sort_expr| has_all_column_refs(&sort_expr.expr, &preserved_cols)); + if !all_from_preserved { + return Ok(Transformed::no(plan)); + } + + let preserved_child = if preserved_is_left { + &join.left + } else { + &join.right + }; + + // Scan deep inside the preserved child (through SubqueryAlias and + // Projection layers) to find an existing Sort. If found with same + // exprs, tighten its fetch in-place. Otherwise, insert a new Sort + // directly below the join as the preserved child's wrapper. + let mut inner_child = preserved_child.as_ref(); + let mut deep_resolved_exprs = resolved_sort_exprs.clone(); + loop { + match inner_child { + LogicalPlan::SubqueryAlias(sq) => { + deep_resolved_exprs = resolve_sort_exprs_through_subquery_alias( + &deep_resolved_exprs, + sq, + )?; + inner_child = sq.input.as_ref(); + } + LogicalPlan::Projection(proj) => { + deep_resolved_exprs = resolve_sort_exprs_through_projection( + &deep_resolved_exprs, + proj, + )?; + inner_child = proj.input.as_ref(); + } + _ => break, + } + } + + // If the inner child is a Limit (PushDownLimit hasn't merged it with + // the Sort yet), skip this iteration. PushDownLimit will merge + // Limit → Sort in the next pass, then our rule will tighten the Sort. 
+ if matches!(inner_child, LogicalPlan::Limit(_)) { + return Ok(Transformed::no(plan)); + } + + // Determine action based on existing inner Sort: + // - Same exprs, tighter fetch → skip (already optimal) + // - Same exprs, larger/no fetch → tighten in-place + // - Different exprs or no Sort → insert new Sort below the join + // + // Example (tighten): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=10) + // Child limits to 10, our tighter fetch=5 tightens it in-place. + // + // Example (tighten): Sort(a ASC, fetch=5) → Join → Sort(a ASC) + // Child has no fetch (full sort), tighten to fetch=5. + // + // Example (skip): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=3) + // Child already limits to 3 rows, pushing fetch=5 won't help. + // + // Example (new): Sort(b ASC, fetch=5) → Join → Sort(a ASC, fetch=10) + // Different exprs, insert Sort(b, fetch=5) above preserved child. + let new_preserved_child = if let LogicalPlan::Sort(child_sort) = inner_child { + let same_exprs = sort_exprs_equal(&child_sort.expr, &deep_resolved_exprs); + let child_fetch_tighter = match child_sort.fetch { + Some(child_fetch) => child_fetch <= fetch, + None => false, + }; + if same_exprs && child_fetch_tighter { + return Ok(Transformed::no(plan)); + } + if same_exprs { + // Tighten existing Sort in-place by rebuilding the path + // from preserved child down to the Sort. + rebuild_with_tightened_sort(preserved_child.as_ref(), child_sort, fetch)? + } else { + // Different exprs — insert new Sort above the preserved child. + // If the inner Sort has no fetch, our pushed Sort is the only + // row reduction. If it has a fetch, re-sorting a small set is + // cheap and still reduces rows entering the join. + Arc::new(LogicalPlan::Sort(SortPlan { + expr: resolved_sort_exprs, + input: Arc::clone(preserved_child), + fetch: Some(fetch), + })) + } + } else { + // No existing Sort — insert new Sort below the join. 
+ Arc::new(LogicalPlan::Sort(SortPlan { + expr: resolved_sort_exprs, + input: Arc::clone(preserved_child), + fetch: Some(fetch), + })) + }; + + // Reconstruct the join with the new child + let mut new_join = join.clone(); + if preserved_is_left { + new_join.left = new_preserved_child; + } else { + new_join.right = new_preserved_child; + } + + // Rebuild the tree: join → intermediate nodes → top-level sort + let mut new_sort_input = Arc::new(LogicalPlan::Join(new_join)); + for node in intermediates.into_iter().rev() { + new_sort_input = Arc::new(match node { + LogicalPlan::Projection(proj) => { + let mut new_proj = proj.clone(); + new_proj.input = new_sort_input; + LogicalPlan::Projection(new_proj) + } + LogicalPlan::SubqueryAlias(sq) => LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(new_sort_input, sq.alias.clone())?, + ), + _ => unreachable!(), + }); + } + + Ok(Transformed::yes(LogicalPlan::Sort(SortPlan { + expr: sort.expr.clone(), + input: new_sort_input, + fetch: sort.fetch, + }))) + } + + fn name(&self) -> &str { + "push_down_topk_through_join" + } + + fn apply_order(&self) -> Option<ApplyOrder> { + Some(ApplyOrder::TopDown) + } +} + +/// Resolve sort expressions through a projection by replacing column +/// references with the underlying projection expressions. +/// +/// For example, if sort expr is `b ASC` and projection has `-t1.b AS b`, +/// the resolved sort expr becomes `-t1.b ASC`. +/// +/// Before: +/// ```text +/// Sort: b ASC, fetch=3 +/// Projection: -t1.b AS b +/// Join +/// t1 +/// t2 +/// ``` +/// +/// After resolving, the pushed Sort uses pre-projection expressions: +/// ```text +/// Sort: b ASC, fetch=3 +/// Projection: -t1.b AS b +/// Join +/// Sort: -t1.b ASC, fetch=3 ← resolved through projection +/// t1 +/// t2 +/// ``` Review Comment: I think this doc comment describes a different function (`resolve_sort_exprs_through_projection`), so it should be removed here.
########## datafusion/optimizer/src/push_down_topk_through_join.rs: ########## @@ -0,0 +1,1138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`PushDownTopKThroughJoin`] pushes TopK (Sort with fetch) through outer joins +//! +//! When a `Sort` with a fetch limit sits above an outer join and all sort +//! expressions come from the **preserved** side, this rule inserts a copy +//! of the `Sort(fetch)` on that input to reduce the number of rows +//! entering the join. +//! +//! This is correct because: +//! - A LEFT JOIN preserves every left row (each appears at least once in the +//! output). The final top-N by left-side columns must come from the top-N +//! left rows. +//! - The same reasoning applies symmetrically for RIGHT JOIN and right-side +//! columns. +//! +//! The top-level sort is kept for correctness since a 1-to-many join can +//! produce more than N output rows from N input rows. +//! +//! ## Example +//! +//! Before: +//! ```text +//! Sort: t1.b ASC, fetch=3 +//! Left Join: t1.a = t2.c +//! Scan: t1 ← scans ALL rows +//! Scan: t2 +//! ``` +//! +//! After: +//! ```text +//! Sort: t1.b ASC, fetch=3 +//! Left Join: t1.a = t2.c +//! Sort: t1.b ASC, fetch=3 ← pushed down +//! Scan: t1 +//! 
Scan: t2 +//! ``` + +use std::sync::Arc; + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use crate::utils::{has_all_column_refs, schema_columns}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{Column, Result}; +use datafusion_expr::logical_plan::{ + JoinType, LogicalPlan, Projection, Sort as SortPlan, SubqueryAlias, +}; +use datafusion_expr::{Expr, SortExpr}; + +/// Optimization rule that pushes TopK (Sort with fetch) through +/// LEFT / RIGHT outer joins when all sort expressions come from +/// the preserved side. +/// +/// See module-level documentation for details. +#[derive(Default, Debug)] +pub struct PushDownTopKThroughJoin; + +impl PushDownTopKThroughJoin { + #[expect(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for PushDownTopKThroughJoin { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result<Transformed<LogicalPlan>> { + // Match Sort with fetch (TopK) + let LogicalPlan::Sort(sort) = &plan else { + return Ok(Transformed::no(plan)); + }; + let Some(fetch) = sort.fetch else { + return Ok(Transformed::no(plan)); + }; + + // Don't push if any sort expression is non-deterministic (e.g. random()). + // Duplicating such expressions would produce different values at each + // evaluation point, potentially changing the result. + if sort.expr.iter().any(|se| se.expr.is_volatile()) { + return Ok(Transformed::no(plan)); + } + + // Peel through transparent nodes (SubqueryAlias, Projection) to find + // the Join. Track intermediate nodes so we can reconstruct the tree + // and resolve sort expressions through them. 
+ let mut current = sort.input.as_ref(); + let mut intermediates: Vec<&LogicalPlan> = Vec::new(); + let join = loop { + match current { + LogicalPlan::Join(join) => break join, + LogicalPlan::Projection(proj) => { + intermediates.push(current); + current = proj.input.as_ref(); + } + LogicalPlan::SubqueryAlias(sq) => { + intermediates.push(current); + current = sq.input.as_ref(); + } + _ => return Ok(Transformed::no(plan)), + } + }; + + // Only outer joins where the preserved side is known. + // Semi/Anti joins are excluded: not all preserved-side rows appear in + // the output (only matched/unmatched rows do), so pushing fetch=N to + // the preserved child can drop rows that would have survived the filter. + // + // Non-equijoin filters in the ON clause are safe: outer joins guarantee + // all preserved-side rows appear in the output regardless of the filter. + // The filter only controls matching (which non-preserved rows pair up), + // not which preserved rows survive. + let preserved_is_left = match join.join_type { + JoinType::Left => true, + JoinType::Right => false, + _ => return Ok(Transformed::no(plan)), + }; + + // Resolve sort expressions through all intermediate nodes (Projection, + // SubqueryAlias) so that column references match the join's schema. + let mut resolved_sort_exprs = sort.expr.clone(); + for node in &intermediates { + match node { + LogicalPlan::Projection(proj) => { + resolved_sort_exprs = resolve_sort_exprs_through_projection( + &resolved_sort_exprs, + proj, + )?; + } + LogicalPlan::SubqueryAlias(sq) => { + resolved_sort_exprs = resolve_sort_exprs_through_subquery_alias( + &resolved_sort_exprs, + sq, + )?; + } + _ => unreachable!(), + } + } + + // After resolving through projections, the sort expressions may now + // contain volatile functions (e.g. `random() AS col`). Duplicating + // volatile expressions in the pushed Sort would produce different + // values, changing results. 
+ if resolved_sort_exprs.iter().any(|se| se.expr.is_volatile()) { + return Ok(Transformed::no(plan)); + } + + let preserved_schema = if preserved_is_left { + join.left.schema() + } else { + join.right.schema() + }; + let preserved_cols = schema_columns(preserved_schema); + + let all_from_preserved = resolved_sort_exprs + .iter() + .all(|sort_expr| has_all_column_refs(&sort_expr.expr, &preserved_cols)); + if !all_from_preserved { + return Ok(Transformed::no(plan)); + } + + let preserved_child = if preserved_is_left { + &join.left + } else { + &join.right + }; + + // Scan deep inside the preserved child (through SubqueryAlias and + // Projection layers) to find an existing Sort. If found with same + // exprs, tighten its fetch in-place. Otherwise, insert a new Sort + // directly below the join as the preserved child's wrapper. + let mut inner_child = preserved_child.as_ref(); + let mut deep_resolved_exprs = resolved_sort_exprs.clone(); + loop { + match inner_child { + LogicalPlan::SubqueryAlias(sq) => { + deep_resolved_exprs = resolve_sort_exprs_through_subquery_alias( + &deep_resolved_exprs, + sq, + )?; + inner_child = sq.input.as_ref(); + } + LogicalPlan::Projection(proj) => { + deep_resolved_exprs = resolve_sort_exprs_through_projection( + &deep_resolved_exprs, + proj, + )?; + inner_child = proj.input.as_ref(); + } + _ => break, + } + } + + // If the inner child is a Limit (PushDownLimit hasn't merged it with + // the Sort yet), skip this iteration. PushDownLimit will merge + // Limit → Sort in the next pass, then our rule will tighten the Sort. 
+ if matches!(inner_child, LogicalPlan::Limit(_)) { + return Ok(Transformed::no(plan)); + } + + // Determine action based on existing inner Sort: + // - Same exprs, tighter fetch → skip (already optimal) + // - Same exprs, larger/no fetch → tighten in-place + // - Different exprs or no Sort → insert new Sort below the join + // + // Example (tighten): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=10) + // Child limits to 10, our tighter fetch=5 tightens it in-place. + // + // Example (tighten): Sort(a ASC, fetch=5) → Join → Sort(a ASC) + // Child has no fetch (full sort), tighten to fetch=5. + // + // Example (skip): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=3) + // Child already limits to 3 rows, pushing fetch=5 won't help. + // + // Example (new): Sort(b ASC, fetch=5) → Join → Sort(a ASC, fetch=10) + // Different exprs, insert Sort(b, fetch=5) above preserved child. + let new_preserved_child = if let LogicalPlan::Sort(child_sort) = inner_child { + let same_exprs = sort_exprs_equal(&child_sort.expr, &deep_resolved_exprs); + let child_fetch_tighter = match child_sort.fetch { + Some(child_fetch) => child_fetch <= fetch, + None => false, + }; + if same_exprs && child_fetch_tighter { + return Ok(Transformed::no(plan)); + } + if same_exprs { + // Tighten existing Sort in-place by rebuilding the path + // from preserved child down to the Sort. + rebuild_with_tightened_sort(preserved_child.as_ref(), child_sort, fetch)? + } else { + // Different exprs — insert new Sort above the preserved child. + // If the inner Sort has no fetch, our pushed Sort is the only + // row reduction. If it has a fetch, re-sorting a small set is + // cheap and still reduces rows entering the join. + Arc::new(LogicalPlan::Sort(SortPlan { + expr: resolved_sort_exprs, + input: Arc::clone(preserved_child), + fetch: Some(fetch), + })) + } + } else { + // No existing Sort — insert new Sort below the join. 
+ Arc::new(LogicalPlan::Sort(SortPlan { + expr: resolved_sort_exprs, + input: Arc::clone(preserved_child), + fetch: Some(fetch), + })) + }; + + // Reconstruct the join with the new child + let mut new_join = join.clone(); + if preserved_is_left { + new_join.left = new_preserved_child; + } else { + new_join.right = new_preserved_child; + } + + // Rebuild the tree: join → intermediate nodes → top-level sort + let mut new_sort_input = Arc::new(LogicalPlan::Join(new_join)); + for node in intermediates.into_iter().rev() { + new_sort_input = Arc::new(match node { + LogicalPlan::Projection(proj) => { + let mut new_proj = proj.clone(); + new_proj.input = new_sort_input; + LogicalPlan::Projection(new_proj) + } + LogicalPlan::SubqueryAlias(sq) => LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(new_sort_input, sq.alias.clone())?, + ), + _ => unreachable!(), + }); + } + + Ok(Transformed::yes(LogicalPlan::Sort(SortPlan { + expr: sort.expr.clone(), + input: new_sort_input, + fetch: sort.fetch, + }))) + } + + fn name(&self) -> &str { + "push_down_topk_through_join" + } + + fn apply_order(&self) -> Option<ApplyOrder> { + Some(ApplyOrder::TopDown) + } +} + +/// Resolve sort expressions through a projection by replacing column +/// references with the underlying projection expressions. +/// +/// For example, if sort expr is `b ASC` and projection has `-t1.b AS b`, +/// the resolved sort expr becomes `-t1.b ASC`. +/// +/// Before: +/// ```text +/// Sort: b ASC, fetch=3 +/// Projection: -t1.b AS b +/// Join +/// t1 +/// t2 +/// ``` +/// +/// After resolving, the pushed Sort uses pre-projection expressions: +/// ```text +/// Sort: b ASC, fetch=3 +/// Projection: -t1.b AS b +/// Join +/// Sort: -t1.b ASC, fetch=3 ← resolved through projection +/// t1 +/// t2 +/// ``` +/// Replace column references in sort expressions using a name→expr map. +/// Uses `transform()` for deep replacement (handles nested expressions +/// like `-t1.b` where the column is inside a Negative wrapper). 
+/// +/// Example with `replace_map = {"sub.b" → Column(t1.b)}`: +/// +/// ```text +/// Input: [sub.b ASC] → Output: [t1.b ASC] (simple column) +/// Input: [(- sub.b) ASC] → Output: [(- t1.b) ASC] (nested column) +/// Input: [sub.a ASC, sub.b ASC] → Output: [t1.a ASC, t1.b ASC] (multiple) +/// ``` +fn replace_columns_in_sort_exprs( + sort_exprs: &[SortExpr], + replace_map: &std::collections::HashMap<String, Expr>, +) -> Result<Vec<SortExpr>> { + sort_exprs + .iter() + .map(|sort_expr| { + let new_expr = sort_expr.expr.clone().transform(|expr| { + let replacement = match &expr { + Expr::Column(col) => replace_map.get(&col.flat_name()).cloned(), + _ => None, + }; + Ok(replacement.map_or_else(|| Transformed::no(expr), Transformed::yes)) + })?; + Ok(SortExpr { + expr: new_expr.data, + ..*sort_expr + }) + }) + .collect() +} + +/// Resolve sort expressions through a projection by replacing column +/// references with the underlying projection expressions. +/// +/// Example: sort expr is `neg_b ASC` and projection has `-t1.b AS neg_b`: +/// +/// ```text +/// Input sort exprs: [neg_b ASC] +/// Output sort exprs: [(- t1.b) ASC] +/// ``` +fn resolve_sort_exprs_through_projection( + sort_exprs: &[SortExpr], + projection: &Projection, +) -> Result<Vec<SortExpr>> { + let replace_map: std::collections::HashMap<String, Expr> = projection + .schema + .iter() + .zip(projection.expr.iter()) + .map(|((qualifier, field), expr)| { + let key = Column::from((qualifier, field)).flat_name(); + (key, expr.clone().unalias()) + }) + .collect(); + + replace_columns_in_sort_exprs(sort_exprs, &replace_map) +} + +/// Compare two slices of `SortExpr` for equality. +/// +/// Uses structural equality on the sort expressions (direction, nulls_first, +/// and the expression tree). 
+fn sort_exprs_equal(a: &[SortExpr], b: &[SortExpr]) -> bool { + a.len() == b.len() + && a.iter().zip(b.iter()).all(|(left, right)| { + left.asc == right.asc + && left.nulls_first == right.nulls_first + && left.expr == right.expr + }) +} + +/// Resolve sort expressions through a SubqueryAlias by replacing the alias +/// qualifier with the input schema's qualifier. +/// +/// Example: SubqueryAlias is `sub` wrapping a join whose left input is `t1`: +/// +/// ```text +/// Input sort exprs: [sub.b ASC] +/// Output sort exprs: [t1.b ASC] +/// ``` +fn resolve_sort_exprs_through_subquery_alias( + sort_exprs: &[SortExpr], + subquery_alias: &SubqueryAlias, +) -> Result<Vec<SortExpr>> { + let replace_map: std::collections::HashMap<String, Expr> = subquery_alias + .schema + .iter() + .zip(subquery_alias.input.schema().iter()) + .map(|((alias_qual, alias_field), (input_qual, input_field))| { + let alias_col = Column::from((alias_qual, alias_field)); + let input_col = Column::from((input_qual, input_field)); + (alias_col.flat_name(), Expr::Column(input_col)) + }) + .collect(); + + replace_columns_in_sort_exprs(sort_exprs, &replace_map) +} + +/// Rebuild the tree from `root` down to an existing Sort, tightening the +/// Sort's fetch to `new_fetch`. The path from `root` to the target Sort +/// may contain Projections and SubqueryAliases. 
+/// +/// Before (new_fetch=2): +/// ```text +/// SubqueryAlias(t1) +/// Projection(a, b AS renamed_b) +/// Sort(t1.b ASC, fetch=10) ← target, fetch too large +/// TableScan: t1 +/// ``` +/// +/// After: +/// ```text +/// SubqueryAlias(t1) ← rebuilt +/// Projection(a, b AS renamed_b) ← rebuilt +/// Sort(t1.b ASC, fetch=2) ← tightened +/// TableScan: t1 +/// ``` +fn rebuild_with_tightened_sort( + root: &LogicalPlan, + target_sort: &SortPlan, + new_fetch: usize, +) -> Result<Arc<LogicalPlan>> { + match root { + LogicalPlan::Sort(s) if std::ptr::eq(s, target_sort) => { + Ok(Arc::new(LogicalPlan::Sort(SortPlan { + expr: s.expr.clone(), + input: Arc::clone(&s.input), + fetch: Some(new_fetch), + }))) + } + LogicalPlan::Projection(proj) => { + let new_input = + rebuild_with_tightened_sort(proj.input.as_ref(), target_sort, new_fetch)?; + let mut new_proj = proj.clone(); + new_proj.input = new_input; + Ok(Arc::new(LogicalPlan::Projection(new_proj))) + } + LogicalPlan::SubqueryAlias(sq) => { + let new_input = + rebuild_with_tightened_sort(sq.input.as_ref(), target_sort, new_fetch)?; + Ok(Arc::new(LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(new_input, sq.alias.clone())?, + ))) + } + _ => unreachable!("rebuild_with_tightened_sort: unexpected node"), Review Comment: The target Sort is located by pointer identity (`std::ptr::eq`). If a future refactor introduces a clone along the path (e.g., a `tree_node` rewrite), the identity check no longer matches. The walk then falls through to the `_` arm, and the `unreachable!()` panics at runtime. Consider tracking the path (depth or node kinds) during the deep scan and re-walking it structurally. Alternatively, switch to a `tree_node::transform_down`-based rewrite that isn't keyed on identity. WDYT? ########## datafusion/optimizer/src/push_down_topk_through_join.rs: ########## @@ -0,0 +1,1138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`PushDownTopKThroughJoin`] pushes TopK (Sort with fetch) through outer joins +//! +//! When a `Sort` with a fetch limit sits above an outer join and all sort +//! expressions come from the **preserved** side, this rule inserts a copy +//! of the `Sort(fetch)` on that input to reduce the number of rows +//! entering the join. +//! +//! This is correct because: +//! - A LEFT JOIN preserves every left row (each appears at least once in the +//! output). The final top-N by left-side columns must come from the top-N +//! left rows. +//! - The same reasoning applies symmetrically for RIGHT JOIN and right-side +//! columns. +//! +//! The top-level sort is kept for correctness since a 1-to-many join can +//! produce more than N output rows from N input rows. +//! +//! ## Example +//! +//! Before: +//! ```text +//! Sort: t1.b ASC, fetch=3 +//! Left Join: t1.a = t2.c +//! Scan: t1 ← scans ALL rows +//! Scan: t2 +//! ``` +//! +//! After: +//! ```text +//! Sort: t1.b ASC, fetch=3 +//! Left Join: t1.a = t2.c +//! Sort: t1.b ASC, fetch=3 ← pushed down +//! Scan: t1 +//! Scan: t2 +//! 
``` + +use std::sync::Arc; + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use crate::utils::{has_all_column_refs, schema_columns}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{Column, Result}; +use datafusion_expr::logical_plan::{ + JoinType, LogicalPlan, Projection, Sort as SortPlan, SubqueryAlias, +}; +use datafusion_expr::{Expr, SortExpr}; + +/// Optimization rule that pushes TopK (Sort with fetch) through +/// LEFT / RIGHT outer joins when all sort expressions come from +/// the preserved side. +/// +/// See module-level documentation for details. +#[derive(Default, Debug)] +pub struct PushDownTopKThroughJoin; + +impl PushDownTopKThroughJoin { + #[expect(missing_docs)] Review Comment: `#[expect(missing_docs)]` is not required here; just add a short doc comment to `new()` instead. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
