This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 98647e842a Stop copying LogicalPlan and Exprs in `PushDownLimit` (#10508)
98647e842a is described below

commit 98647e842a85b768ea0cb0f8ccf1016636001abb
Author: Andrew Lamb <[email protected]>
AuthorDate: Fri May 17 04:14:32 2024 -0400

    Stop copying LogicalPlan and Exprs in `PushDownLimit` (#10508)
    
    * Stop copying LogicalPlan and Exprs in `PushDownLimit`
    
    * Refine make_limit
---
 datafusion/optimizer/src/push_down_limit.rs | 275 +++++++++++++++-------------
 1 file changed, 149 insertions(+), 126 deletions(-)

diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs
index 9190881335..b97dff74d9 100644
--- a/datafusion/optimizer/src/push_down_limit.rs
+++ b/datafusion/optimizer/src/push_down_limit.rs
@@ -23,11 +23,10 @@ use std::sync::Arc;
 use crate::optimizer::ApplyOrder;
 use crate::{OptimizerConfig, OptimizerRule};
 
-use datafusion_common::Result;
-use datafusion_expr::logical_plan::{
-    Join, JoinType, Limit, LogicalPlan, Sort, TableScan, Union,
-};
-use datafusion_expr::CrossJoin;
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::{internal_err, Result};
+use datafusion_expr::logical_plan::tree_node::unwrap_arc;
+use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan};
 
 /// Optimization rule that tries to push down `LIMIT`.
 ///
@@ -46,131 +45,120 @@ impl PushDownLimit {
 impl OptimizerRule for PushDownLimit {
     fn try_optimize(
         &self,
-        plan: &LogicalPlan,
+        _plan: &LogicalPlan,
         _config: &dyn OptimizerConfig,
     ) -> Result<Option<LogicalPlan>> {
-        use std::cmp::min;
+        internal_err!("Should have called PushDownLimit::rewrite")
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
 
-        let LogicalPlan::Limit(limit) = plan else {
-            return Ok(None);
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        let LogicalPlan::Limit(mut limit) = plan else {
+            return Ok(Transformed::no(plan));
         };
 
-        if let LogicalPlan::Limit(child) = &*limit.input {
-            // Merge the Parent Limit and the Child Limit.
+        let Limit { skip, fetch, input } = limit;
+        let input = input;
+
+        // Merge the Parent Limit and the Child Limit.
+        if let LogicalPlan::Limit(child) = input.as_ref() {
             let (skip, fetch) =
                 combine_limit(limit.skip, limit.fetch, child.skip, child.fetch);
 
             let plan = LogicalPlan::Limit(Limit {
                 skip,
                 fetch,
-                input: Arc::new((*child.input).clone()),
+                input: Arc::clone(&child.input),
             });
-            return self
-                .try_optimize(&plan, _config)
-                .map(|opt_plan| opt_plan.or_else(|| Some(plan)));
+
+            // recursively reapply the rule on the new plan
+            return self.rewrite(plan, _config);
         }
 
-        let Some(fetch) = limit.fetch else {
-            return Ok(None);
+        // no fetch to push, so return the original plan
+        let Some(fetch) = fetch else {
+            return Ok(Transformed::no(LogicalPlan::Limit(Limit {
+                skip,
+                fetch,
+                input,
+            })));
         };
-        let skip = limit.skip;
 
-        match limit.input.as_ref() {
-            LogicalPlan::TableScan(scan) => {
-                let limit = if fetch != 0 { fetch + skip } else { 0 };
-                let new_fetch = scan.fetch.map(|x| min(x, limit)).or(Some(limit));
+        match unwrap_arc(input) {
+            LogicalPlan::TableScan(mut scan) => {
+                let rows_needed = if fetch != 0 { fetch + skip } else { 0 };
+                let new_fetch = scan
+                    .fetch
+                    .map(|x| min(x, rows_needed))
+                    .or(Some(rows_needed));
                 if new_fetch == scan.fetch {
-                    Ok(None)
+                    original_limit(skip, fetch, LogicalPlan::TableScan(scan))
                 } else {
-                    let new_input = LogicalPlan::TableScan(TableScan {
-                        table_name: scan.table_name.clone(),
-                        source: scan.source.clone(),
-                        projection: scan.projection.clone(),
-                        filters: scan.filters.clone(),
-                        fetch: scan.fetch.map(|x| min(x, limit)).or(Some(limit)),
-                        projected_schema: scan.projected_schema.clone(),
-                    });
-                    plan.with_new_exprs(plan.expressions(), vec![new_input])
-                        .map(Some)
+                    // push limit into the table scan itself
+                    scan.fetch = scan
+                        .fetch
+                        .map(|x| min(x, rows_needed))
+                        .or(Some(rows_needed));
+                    transformed_limit(skip, fetch, LogicalPlan::TableScan(scan))
                 }
             }
-            LogicalPlan::Union(union) => {
-                let new_inputs = union
+            LogicalPlan::Union(mut union) => {
+                // push limits to each input of the union
+                union.inputs = union
                     .inputs
-                    .iter()
-                    .map(|x| {
-                        Ok(Arc::new(LogicalPlan::Limit(Limit {
-                            skip: 0,
-                            fetch: Some(fetch + skip),
-                            input: x.clone(),
-                        })))
-                    })
-                    .collect::<Result<_>>()?;
-                let union = LogicalPlan::Union(Union {
-                    inputs: new_inputs,
-                    schema: union.schema.clone(),
-                });
-                plan.with_new_exprs(plan.expressions(), vec![union])
-                    .map(Some)
+                    .into_iter()
+                    .map(|input| make_arc_limit(0, fetch + skip, input))
+                    .collect();
+                transformed_limit(skip, fetch, LogicalPlan::Union(union))
             }
 
-            LogicalPlan::CrossJoin(cross_join) => {
-                let new_left = LogicalPlan::Limit(Limit {
-                    skip: 0,
-                    fetch: Some(fetch + skip),
-                    input: cross_join.left.clone(),
-                });
-                let new_right = LogicalPlan::Limit(Limit {
-                    skip: 0,
-                    fetch: Some(fetch + skip),
-                    input: cross_join.right.clone(),
-                });
-                let new_cross_join = LogicalPlan::CrossJoin(CrossJoin {
-                    left: Arc::new(new_left),
-                    right: Arc::new(new_right),
-                    schema: plan.schema().clone(),
-                });
-                plan.with_new_exprs(plan.expressions(), vec![new_cross_join])
-                    .map(Some)
+            LogicalPlan::CrossJoin(mut cross_join) => {
+                // push limit to both inputs
+                cross_join.left = make_arc_limit(0, fetch + skip, cross_join.left);
+                cross_join.right = make_arc_limit(0, fetch + skip, cross_join.right);
+                transformed_limit(skip, fetch, LogicalPlan::CrossJoin(cross_join))
             }
 
-            LogicalPlan::Join(join) => {
-                if let Some(new_join) = push_down_join(join, fetch + skip) {
-                    let inputs = vec![LogicalPlan::Join(new_join)];
-                    plan.with_new_exprs(plan.expressions(), inputs).map(Some)
-                } else {
-                    Ok(None)
-                }
-            }
+            LogicalPlan::Join(join) => Ok(push_down_join(join, fetch + skip)
+                .update_data(|join| {
+                    make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join)))
+                })),
 
-            LogicalPlan::Sort(sort) => {
+            LogicalPlan::Sort(mut sort) => {
                 let new_fetch = {
                     let sort_fetch = skip + fetch;
                     Some(sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch))
                 };
                 if new_fetch == sort.fetch {
-                    Ok(None)
+                    original_limit(skip, fetch, LogicalPlan::Sort(sort))
                 } else {
-                    let new_sort = LogicalPlan::Sort(Sort {
-                        expr: sort.expr.clone(),
-                        input: sort.input.clone(),
-                        fetch: new_fetch,
-                    });
-                    plan.with_new_exprs(plan.expressions(), vec![new_sort])
-                        .map(Some)
+                    sort.fetch = new_fetch;
+                    limit.input = Arc::new(LogicalPlan::Sort(sort));
+                    Ok(Transformed::yes(LogicalPlan::Limit(limit)))
                 }
             }
-            child_plan @ (LogicalPlan::Projection(_) | LogicalPlan::SubqueryAlias(_)) => {
+            LogicalPlan::Projection(mut proj) => {
                 // commute
-                let new_limit = plan.with_new_exprs(
-                    plan.expressions(),
-                    vec![child_plan.inputs()[0].clone()],
-                )?;
-                child_plan
-                    .with_new_exprs(child_plan.expressions(), vec![new_limit])
-                    .map(Some)
+                limit.input = Arc::clone(&proj.input);
+                let new_limit = LogicalPlan::Limit(limit);
+                proj.input = Arc::new(new_limit);
+                Ok(Transformed::yes(LogicalPlan::Projection(proj)))
             }
-            _ => Ok(None),
+            LogicalPlan::SubqueryAlias(mut subquery_alias) => {
+                // commute
+                limit.input = Arc::clone(&subquery_alias.input);
+                let new_limit = LogicalPlan::Limit(limit);
+                subquery_alias.input = Arc::new(new_limit);
+                Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias)))
+            }
+            input => original_limit(skip, fetch, input),
         }
     }
 
@@ -183,6 +171,61 @@ impl OptimizerRule for PushDownLimit {
     }
 }
 
+/// Wrap the input plan with a limit node
+///
+/// Original:
+/// ```text
+/// input
+/// ```
+///
+/// Return
+/// ```text
+/// Limit: skip=skip, fetch=fetch
+///  input
+/// ```
+fn make_limit(skip: usize, fetch: usize, input: Arc<LogicalPlan>) -> LogicalPlan {
+    LogicalPlan::Limit(Limit {
+        skip,
+        fetch: Some(fetch),
+        input,
+    })
+}
+
+/// Wrap the input plan with a limit node
+fn make_arc_limit(
+    skip: usize,
+    fetch: usize,
+    input: Arc<LogicalPlan>,
+) -> Arc<LogicalPlan> {
+    Arc::new(make_limit(skip, fetch, input))
+}
+
+/// Returns the original limit (non transformed)
+fn original_limit(
+    skip: usize,
+    fetch: usize,
+    input: LogicalPlan,
+) -> Result<Transformed<LogicalPlan>> {
+    Ok(Transformed::no(LogicalPlan::Limit(Limit {
+        skip,
+        fetch: Some(fetch),
+        input: Arc::new(input),
+    })))
+}
+
+/// Returns the a transformed limit
+fn transformed_limit(
+    skip: usize,
+    fetch: usize,
+    input: LogicalPlan,
+) -> Result<Transformed<LogicalPlan>> {
+    Ok(Transformed::yes(LogicalPlan::Limit(Limit {
+        skip,
+        fetch: Some(fetch),
+        input: Arc::new(input),
+    })))
+}
+
 /// Combines two limits into a single
 ///
 /// Returns the combined limit `(skip, fetch)`
@@ -255,14 +298,15 @@ fn combine_limit(
     (combined_skip, combined_fetch)
 }
 
-fn push_down_join(join: &Join, limit: usize) -> Option<Join> {
+/// Adds a limit to the inputs of a join, if possible
+fn push_down_join(mut join: Join, limit: usize) -> Transformed<Join> {
     use JoinType::*;
 
     fn is_no_join_condition(join: &Join) -> bool {
         join.on.is_empty() && join.filter.is_none()
     }
 
-    let (left_limit, right_limit) = if is_no_join_condition(join) {
+    let (left_limit, right_limit) = if is_no_join_condition(&join) {
         match join.join_type {
             Left | Right | Full => (Some(limit), Some(limit)),
             LeftAnti | LeftSemi => (Some(limit), None),
@@ -277,37 +321,16 @@ fn push_down_join(join: &Join, limit: usize) -> Option<Join> {
         }
     };
 
-    match (left_limit, right_limit) {
-        (None, None) => None,
-        _ => {
-            let left = match left_limit {
-                Some(limit) => Arc::new(LogicalPlan::Limit(Limit {
-                    skip: 0,
-                    fetch: Some(limit),
-                    input: join.left.clone(),
-                })),
-                None => join.left.clone(),
-            };
-            let right = match right_limit {
-                Some(limit) => Arc::new(LogicalPlan::Limit(Limit {
-                    skip: 0,
-                    fetch: Some(limit),
-                    input: join.right.clone(),
-                })),
-                None => join.right.clone(),
-            };
-            Some(Join {
-                left,
-                right,
-                on: join.on.clone(),
-                filter: join.filter.clone(),
-                join_type: join.join_type,
-                join_constraint: join.join_constraint,
-                schema: join.schema.clone(),
-                null_equals_null: join.null_equals_null,
-            })
-        }
+    if left_limit.is_none() && right_limit.is_none() {
+        return Transformed::no(join);
+    }
+    if let Some(limit) = left_limit {
+        join.left = make_arc_limit(0, limit, join.left);
+    }
+    if let Some(limit) = right_limit {
+        join.right = make_arc_limit(0, limit, join.right);
     }
+    Transformed::yes(join)
 }
 
 #[cfg(test)]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to