alamb commented on code in PR #10508:
URL: https://github.com/apache/datafusion/pull/10508#discussion_r1600350491


##########
datafusion/optimizer/src/push_down_limit.rs:
##########
@@ -45,166 +45,125 @@ impl PushDownLimit {
 impl OptimizerRule for PushDownLimit {
     fn try_optimize(
         &self,
-        plan: &LogicalPlan,
+        _plan: &LogicalPlan,
         _config: &dyn OptimizerConfig,
     ) -> Result<Option<LogicalPlan>> {
-        use std::cmp::min;
+        internal_err!("Should have called PushDownLimit::rewrite")
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
 
-        let LogicalPlan::Limit(limit) = plan else {
-            return Ok(None);
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        let LogicalPlan::Limit(mut limit) = plan else {
+            return Ok(Transformed::no(plan));
         };
 
-        if let LogicalPlan::Limit(child) = &*limit.input {
-            // Merge the Parent Limit and the Child Limit.
-
-            //  Case 0: Parent and Child are disjoint. (child_fetch <= skip)
-            //   Before merging:
-            //                     |........skip........|---fetch-->|          
    Parent Limit
-            //    |...child_skip...|---child_fetch-->|                         
    Child Limit
-            //   After merging:
-            //    |.........(child_skip + skip).........|
-            //   Before merging:
-            //                     |...skip...|------------fetch------------>| 
    Parent Limit
-            //    |...child_skip...|-------------child_fetch------------>|     
    Child Limit
-            //   After merging:
-            //    |....(child_skip + skip)....|---(child_fetch - skip)-->|
-
-            //  Case 1: Parent is beyond the range of Child. (skip < 
child_fetch <= skip + fetch)
-            //   Before merging:
-            //                     |...skip...|------------fetch------------>| 
    Parent Limit
-            //    |...child_skip...|-------------child_fetch------------>|     
    Child Limit
-            //   After merging:
-            //    |....(child_skip + skip)....|---(child_fetch - skip)-->|
-
-            //  Case 2: Parent is in the range of Child. (skip + fetch < 
child_fetch)
-            //   Before merging:
-            //                     |...skip...|---fetch-->|                    
    Parent Limit
-            //    |...child_skip...|-------------child_fetch------------>|     
    Child Limit
-            //   After merging:
-            //    |....(child_skip + skip)....|---fetch-->|
-            let parent_skip = limit.skip;
-            let new_fetch = match (limit.fetch, child.fetch) {
-                (Some(fetch), Some(child_fetch)) => {
-                    Some(min(fetch, child_fetch.saturating_sub(parent_skip)))
-                }
-                (Some(fetch), None) => Some(fetch),
-                (None, Some(child_fetch)) => {
-                    Some(child_fetch.saturating_sub(parent_skip))
-                }
-                (None, None) => None,
-            };
+        let Limit { skip, fetch, input } = limit;
+        let input = input;
+
+        // Merge the Parent Limit and the Child Limit.
+        if let LogicalPlan::Limit(child) = input.as_ref() {
+            let (skip, fetch) =
+                combine_limit(limit.skip, limit.fetch, child.skip, 
child.fetch);
 
             let plan = LogicalPlan::Limit(Limit {
-                skip: child.skip + parent_skip,
-                fetch: new_fetch,
-                input: Arc::new((*child.input).clone()),
+                skip,
+                fetch,
+                input: Arc::clone(&child.input),
             });
-            return self
-                .try_optimize(&plan, _config)
-                .map(|opt_plan| opt_plan.or_else(|| Some(plan)));
+
+            // recursively reapply the rule on the new plan
+            return self.rewrite(plan, _config);
         }
 
-        let Some(fetch) = limit.fetch else {
-            return Ok(None);
+        // no fetch to push, so return the original plan
+        let Some(fetch) = fetch else {
+            return Ok(Transformed::no(LogicalPlan::Limit(Limit {
+                skip,
+                fetch,
+                input,
+            })));
         };
-        let skip = limit.skip;
 
-        match limit.input.as_ref() {
-            LogicalPlan::TableScan(scan) => {
-                let limit = if fetch != 0 { fetch + skip } else { 0 };
-                let new_fetch = scan.fetch.map(|x| min(x, 
limit)).or(Some(limit));
+        match unwrap_arc(input) {
+            LogicalPlan::TableScan(mut scan) => {
+                let rows_needed = if fetch != 0 { fetch + skip } else { 0 };
+                let new_fetch = scan
+                    .fetch
+                    .map(|x| min(x, rows_needed))
+                    .or(Some(rows_needed));
                 if new_fetch == scan.fetch {
-                    Ok(None)
+                    let input = Arc::new(LogicalPlan::TableScan(scan));
+                    original_limit(skip, fetch, input)
                 } else {
-                    let new_input = LogicalPlan::TableScan(TableScan {
-                        table_name: scan.table_name.clone(),
-                        source: scan.source.clone(),
-                        projection: scan.projection.clone(),
-                        filters: scan.filters.clone(),
-                        fetch: scan.fetch.map(|x| min(x, 
limit)).or(Some(limit)),
-                        projected_schema: scan.projected_schema.clone(),
-                    });
-                    plan.with_new_exprs(plan.expressions(), vec![new_input])
-                        .map(Some)
+                    // push limit into the table scan itself
+                    scan.fetch = scan
+                        .fetch
+                        .map(|x| min(x, rows_needed))
+                        .or(Some(rows_needed));
+                    let input = Arc::new(LogicalPlan::TableScan(scan));
+                    transformed_limit(skip, fetch, input)
                 }
             }
-            LogicalPlan::Union(union) => {
-                let new_inputs = union
+            LogicalPlan::Union(mut union) => {
+                // push limits to each input of the union
+                union.inputs = union
                     .inputs
-                    .iter()
-                    .map(|x| {
-                        Ok(Arc::new(LogicalPlan::Limit(Limit {
-                            skip: 0,
-                            fetch: Some(fetch + skip),
-                            input: x.clone(),
-                        })))
-                    })
-                    .collect::<Result<_>>()?;
-                let union = LogicalPlan::Union(Union {
-                    inputs: new_inputs,
-                    schema: union.schema.clone(),
-                });
-                plan.with_new_exprs(plan.expressions(), vec![union])
-                    .map(Some)
+                    .into_iter()
+                    .map(|input| make_arc_limit(0, fetch + skip, input))

Review Comment:
   I also moved some of the boilerplate for creating `Limit` nodes into dedicated 
helper functions (`make_arc_limit`, `original_limit`, and `transformed_limit`).



##########
datafusion/optimizer/src/push_down_limit.rs:
##########
@@ -45,166 +45,125 @@ impl PushDownLimit {
 impl OptimizerRule for PushDownLimit {
     fn try_optimize(
         &self,
-        plan: &LogicalPlan,
+        _plan: &LogicalPlan,
         _config: &dyn OptimizerConfig,
     ) -> Result<Option<LogicalPlan>> {
-        use std::cmp::min;
+        internal_err!("Should have called PushDownLimit::rewrite")
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
 
-        let LogicalPlan::Limit(limit) = plan else {
-            return Ok(None);
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        let LogicalPlan::Limit(mut limit) = plan else {
+            return Ok(Transformed::no(plan));
         };
 
-        if let LogicalPlan::Limit(child) = &*limit.input {
-            // Merge the Parent Limit and the Child Limit.
-
-            //  Case 0: Parent and Child are disjoint. (child_fetch <= skip)
-            //   Before merging:
-            //                     |........skip........|---fetch-->|          
    Parent Limit
-            //    |...child_skip...|---child_fetch-->|                         
    Child Limit
-            //   After merging:
-            //    |.........(child_skip + skip).........|
-            //   Before merging:
-            //                     |...skip...|------------fetch------------>| 
    Parent Limit
-            //    |...child_skip...|-------------child_fetch------------>|     
    Child Limit
-            //   After merging:
-            //    |....(child_skip + skip)....|---(child_fetch - skip)-->|
-
-            //  Case 1: Parent is beyond the range of Child. (skip < 
child_fetch <= skip + fetch)
-            //   Before merging:
-            //                     |...skip...|------------fetch------------>| 
    Parent Limit
-            //    |...child_skip...|-------------child_fetch------------>|     
    Child Limit
-            //   After merging:
-            //    |....(child_skip + skip)....|---(child_fetch - skip)-->|
-
-            //  Case 2: Parent is in the range of Child. (skip + fetch < 
child_fetch)
-            //   Before merging:
-            //                     |...skip...|---fetch-->|                    
    Parent Limit
-            //    |...child_skip...|-------------child_fetch------------>|     
    Child Limit
-            //   After merging:
-            //    |....(child_skip + skip)....|---fetch-->|
-            let parent_skip = limit.skip;
-            let new_fetch = match (limit.fetch, child.fetch) {
-                (Some(fetch), Some(child_fetch)) => {
-                    Some(min(fetch, child_fetch.saturating_sub(parent_skip)))
-                }
-                (Some(fetch), None) => Some(fetch),
-                (None, Some(child_fetch)) => {
-                    Some(child_fetch.saturating_sub(parent_skip))
-                }
-                (None, None) => None,
-            };
+        let Limit { skip, fetch, input } = limit;
+        let input = input;
+
+        // Merge the Parent Limit and the Child Limit.
+        if let LogicalPlan::Limit(child) = input.as_ref() {
+            let (skip, fetch) =
+                combine_limit(limit.skip, limit.fetch, child.skip, 
child.fetch);
 
             let plan = LogicalPlan::Limit(Limit {
-                skip: child.skip + parent_skip,
-                fetch: new_fetch,
-                input: Arc::new((*child.input).clone()),
+                skip,
+                fetch,
+                input: Arc::clone(&child.input),
             });
-            return self
-                .try_optimize(&plan, _config)
-                .map(|opt_plan| opt_plan.or_else(|| Some(plan)));
+
+            // recursively reapply the rule on the new plan
+            return self.rewrite(plan, _config);
         }
 
-        let Some(fetch) = limit.fetch else {
-            return Ok(None);
+        // no fetch to push, so return the original plan
+        let Some(fetch) = fetch else {
+            return Ok(Transformed::no(LogicalPlan::Limit(Limit {
+                skip,
+                fetch,
+                input,
+            })));
         };
-        let skip = limit.skip;
 
-        match limit.input.as_ref() {
-            LogicalPlan::TableScan(scan) => {
-                let limit = if fetch != 0 { fetch + skip } else { 0 };
-                let new_fetch = scan.fetch.map(|x| min(x, 
limit)).or(Some(limit));
+        match unwrap_arc(input) {
+            LogicalPlan::TableScan(mut scan) => {
+                let rows_needed = if fetch != 0 { fetch + skip } else { 0 };
+                let new_fetch = scan
+                    .fetch
+                    .map(|x| min(x, rows_needed))
+                    .or(Some(rows_needed));
                 if new_fetch == scan.fetch {
-                    Ok(None)
+                    let input = Arc::new(LogicalPlan::TableScan(scan));
+                    original_limit(skip, fetch, input)
                 } else {
-                    let new_input = LogicalPlan::TableScan(TableScan {
-                        table_name: scan.table_name.clone(),
-                        source: scan.source.clone(),
-                        projection: scan.projection.clone(),
-                        filters: scan.filters.clone(),
-                        fetch: scan.fetch.map(|x| min(x, 
limit)).or(Some(limit)),
-                        projected_schema: scan.projected_schema.clone(),
-                    });
-                    plan.with_new_exprs(plan.expressions(), vec![new_input])

Review Comment:
   `plan.with_new_exprs` copies the plan's expressions in addition to all the other 
clones above, so this change removes a non-trivial number of clones.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to