SubhamSinghal commented on code in PR #21621:
URL: https://github.com/apache/datafusion/pull/21621#discussion_r3275160712


##########
datafusion/optimizer/src/push_down_topk_through_join.rs:
##########
@@ -0,0 +1,1318 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`PushDownTopKThroughJoin`] pushes TopK (Sort with fetch) through joins
+//! whose preserved side is known.
+//!
+//! When a `Sort` with a fetch limit sits above such a join and all sort
+//! expressions come from the **preserved** side, this rule inserts a copy
+//! of the `Sort(fetch)` on that input to reduce the number of rows
+//! entering the join.
+//!
+//! This is correct because:
+//! - A LEFT JOIN preserves every left row (each appears at least once in the
+//!   output). The final top-N by left-side columns must come from the top-N
+//!   left rows.
+//! - The same reasoning applies symmetrically for RIGHT JOIN and right-side
+//!   columns.
+//! - A CROSS JOIN preserves every row from both sides (Cartesian product).
+//!   The top-N by one side's columns must come from the top-N rows of that
+//!   side, since each surviving row is duplicated by the other side's row
+//!   count.
+//! - LEFT MARK / RIGHT MARK joins emit exactly one record per row of the
+//!   marked side (with an extra mark column), so that side is fully
+//!   preserved and pushdown applies in the same way as for LEFT / RIGHT joins.
+//!
+//! The top-level sort is kept for correctness since a 1-to-many join can
+//! produce more than N output rows from N input rows.
+//!
+//! ## Example
+//!
+//! Before:
+//! ```text
+//! Sort: t1.b ASC, fetch=3
+//!   Left Join: t1.a = t2.c
+//!     Scan: t1     ← scans ALL rows
+//!     Scan: t2
+//! ```
+//!
+//! After:
+//! ```text
+//! Sort: t1.b ASC, fetch=3
+//!   Left Join: t1.a = t2.c
+//!     Sort: t1.b ASC, fetch=3  ← pushed down
+//!       Scan: t1
+//!     Scan: t2
+//! ```
+
+use std::sync::Arc;
+
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
+
+use crate::utils::{has_all_column_refs, schema_columns};
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::{Column, Result, internal_err};
+use datafusion_expr::logical_plan::{
+    JoinType, LogicalPlan, Projection, Sort as SortPlan, SubqueryAlias,
+};
+use datafusion_expr::{Expr, SortExpr};
+
+/// Identifies which join input (left or right child) is treated as the preserved side.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum Side {
+    Left, // the join's left input (`join.left`)
+    Right, // the join's right input (`join.right`)
+}
+
+/// Optimizer rule that copies a TopK (`Sort` with `fetch`) below a join
+/// that has a known preserved side (LEFT / RIGHT outer,
+/// LEFT MARK / RIGHT MARK, or CROSS) when all sort expressions come
+/// from a preserved side. The original `Sort` above the join is kept.
+///
+/// See module-level documentation for details.
+#[derive(Default, Debug)]
+pub struct PushDownTopKThroughJoin;
+
+impl PushDownTopKThroughJoin {
+    /// Create a new `PushDownTopKThroughJoin` rule; the struct has no fields to configure.
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl OptimizerRule for PushDownTopKThroughJoin {
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        // Match Sort with fetch (TopK)
+        let LogicalPlan::Sort(sort) = &plan else {
+            return Ok(Transformed::no(plan));
+        };
+        let Some(fetch) = sort.fetch else {
+            return Ok(Transformed::no(plan));
+        };
+
+        // Don't push if any sort expression is non-deterministic (e.g. random()).
+        // Duplicating such expressions would produce different values at each
+        // evaluation point, potentially changing the result.
+        if sort.expr.iter().any(|se| se.expr.is_volatile()) {
+            return Ok(Transformed::no(plan));
+        }
+
+        // Peel through transparent nodes (SubqueryAlias, Projection) to find
+        // the Join. Track intermediate nodes so we can reconstruct the tree
+        // and resolve sort expressions through them.
+        let mut current = sort.input.as_ref();
+        let mut intermediates: Vec<&LogicalPlan> = Vec::new();
+        let join = loop {
+            match current {
+                LogicalPlan::Join(join) => break join,
+                LogicalPlan::Projection(proj) => {
+                    intermediates.push(current);
+                    current = proj.input.as_ref();
+                }
+                LogicalPlan::SubqueryAlias(sq) => {
+                    intermediates.push(current);
+                    current = sq.input.as_ref();
+                }
+                _ => return Ok(Transformed::no(plan)),
+            }
+        };
+
+        // Determine which side(s) of the join are preserved.
+        //
+        // - LEFT / LeftMark: only left preserved (the output schema is
+        //   left + right columns for LEFT, or left + mark column for LeftMark).
+        // - RIGHT / RightMark: symmetric.
+        // - CROSS JOIN (represented as Inner with no `on` keys and no filter):
+        //   every row from both sides appears in the output (Cartesian
+        //   product), so we can push to whichever side has all the sort cols.
+        //
+        // For LEFT/RIGHT, non-equijoin filters in the ON clause are safe:
+        // outer joins guarantee all preserved-side rows appear in the output
+        // regardless of the filter, and the non-preserved side never appears
+        // as a standalone unmatched row.
+        //
+        // For Inner joins (cross-join detection), the filter check is strict
+        // (`filter.is_none()`). When an Inner join has a filter, that filter
+        // can drop rows from either side, so pushing fetch=N may select rows
+        // that get filtered out while discarding rows that would have matched.
+        let preserved_candidates: &[Side] = match join.join_type {
+            JoinType::Left | JoinType::LeftMark => &[Side::Left],
+            JoinType::Right | JoinType::RightMark => &[Side::Right],
+            JoinType::Inner if join.on.is_empty() && join.filter.is_none() => {
+                &[Side::Left, Side::Right]
+            }
+            _ => return Ok(Transformed::no(plan)),
+        };
+
+        // Resolve sort expressions through all intermediate nodes (Projection,
+        // SubqueryAlias) so that column references match the join's schema.
+        let mut resolved_sort_exprs = sort.expr.clone();
+        for node in &intermediates {
+            match node {
+                LogicalPlan::Projection(proj) => {
+                    resolved_sort_exprs = 
resolve_sort_exprs_through_projection(
+                        &resolved_sort_exprs,
+                        proj,
+                    )?;
+                }
+                LogicalPlan::SubqueryAlias(sq) => {
+                    resolved_sort_exprs = 
resolve_sort_exprs_through_subquery_alias(
+                        &resolved_sort_exprs,
+                        sq,
+                    )?;
+                }
+                _ => {
+                    return internal_err!(
+                        "PushDownTopKThroughJoin: unexpected intermediate 
node: {}",
+                        node.display()
+                    );
+                }
+            }
+        }
+
+        // After resolving through projections, the sort expressions may now
+        // contain volatile functions (e.g. `random() AS col`). Duplicating
+        // volatile expressions in the pushed Sort would produce different
+        // values, changing results.
+        if resolved_sort_exprs.iter().any(|se| se.expr.is_volatile()) {
+            return Ok(Transformed::no(plan));
+        }
+
+        // Pick the first preserved-side candidate whose schema contains all
+        // referenced sort columns. For LEFT/RIGHT this is the fixed side;
+        // for CROSS we try both.
+        let Some(preserved_side) = 
preserved_candidates.iter().copied().find(|&side| {
+            let schema = match side {
+                Side::Left => join.left.schema(),
+                Side::Right => join.right.schema(),
+            };
+            let cols = schema_columns(schema);
+            resolved_sort_exprs
+                .iter()
+                .all(|se| has_all_column_refs(&se.expr, &cols))
+        }) else {
+            return Ok(Transformed::no(plan));
+        };
+
+        let preserved_child = match preserved_side {
+            Side::Left => &join.left,
+            Side::Right => &join.right,
+        };
+
+        // Scan deep inside the preserved child (through SubqueryAlias and
+        // Projection layers) to find an existing Sort. If found with same
+        // exprs, tighten its fetch in-place. Otherwise, insert a new Sort
+        // directly below the join as the preserved child's wrapper.
+        let mut inner_child = preserved_child.as_ref();
+        let mut deep_resolved_exprs = resolved_sort_exprs.clone();
+        loop {
+            match inner_child {
+                LogicalPlan::SubqueryAlias(sq) => {
+                    deep_resolved_exprs = 
resolve_sort_exprs_through_subquery_alias(
+                        &deep_resolved_exprs,
+                        sq,
+                    )?;
+                    inner_child = sq.input.as_ref();
+                }
+                LogicalPlan::Projection(proj) => {
+                    deep_resolved_exprs = 
resolve_sort_exprs_through_projection(
+                        &deep_resolved_exprs,
+                        proj,
+                    )?;
+                    inner_child = proj.input.as_ref();
+                }
+                _ => break,
+            }
+        }
+
+        // If the inner child is a Limit (PushDownLimit hasn't merged it with
+        // the Sort yet), skip this iteration.  PushDownLimit will merge
+        // Limit → Sort in the next pass, then our rule will tighten the Sort.
+        if matches!(inner_child, LogicalPlan::Limit(_)) {
+            return Ok(Transformed::no(plan));
+        }
+
+        // Determine action based on existing inner Sort:
+        // - Same exprs, child fetch <= ours → skip (already at least as tight)
+        // - Same exprs, larger/no fetch → tighten in-place
+        // - Different exprs or no Sort → insert new Sort below the join
+        //
+        // Example (tighten): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=10)
+        //   Child limits to 10, our tighter fetch=5 tightens it in-place.
+        //
+        // Example (tighten): Sort(a ASC, fetch=5) → Join → Sort(a ASC)
+        //   Child has no fetch (full sort), tighten to fetch=5.
+        //
+        // Example (skip): Sort(a ASC, fetch=5) → Join → Sort(a ASC, fetch=3)
+        //   Child already limits to 3 rows, pushing fetch=5 won't help.
+        //
+        // Example (new): Sort(b ASC, fetch=5) → Join → Sort(a ASC, fetch=10)
+        //   Different exprs, insert Sort(b, fetch=5) above preserved child.
+        let new_preserved_child = if let LogicalPlan::Sort(child_sort) = 
inner_child {
+            let same_exprs = sort_exprs_equal(&child_sort.expr, 
&deep_resolved_exprs);
+            let child_fetch_tighter = match child_sort.fetch {
+                Some(child_fetch) => child_fetch <= fetch,
+                None => false,
+            };
+            if same_exprs && child_fetch_tighter {
+                return Ok(Transformed::no(plan));
+            }
+            if same_exprs {
+                // Tighten existing Sort in-place by rebuilding the path

Review Comment:
   @gene-bordegaray Thanks for highlighting this. Resolved.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to