This is an automated email from the ASF dual-hosted git repository.

thinkharderdev pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/master by this push:
     new 0376f374 Add support for SortPreservingMergeExec; fix LIMIT bug (#304)
0376f374 is described below

commit 0376f37447d22e98134d8c7dc48ec68a2978f870
Author: Andy Grove <[email protected]>
AuthorDate: Sat Oct 1 13:15:27 2022 -0600

    Add support for SortPreservingMergeExec; fix LIMIT bug (#304)
    
    * Add support for SortPreservingMergeExec
    
    * Add support for SortExec::fetch
---
 ballista/rust/core/proto/ballista.proto           |  12 +-
 ballista/rust/core/src/serde/physical_plan/mod.rs | 130 +++++++++++++++++++---
 2 files changed, 126 insertions(+), 16 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index a2b5f1fd..8923cbc9 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -79,6 +79,7 @@ message PhysicalPlanNode {
     PhysicalExtensionNode extension = 21;
     UnionExecNode union = 22;
     ExplainExecNode explain = 23;
+    SortPreservingMergeExecNode sort_preserving_merge = 24;
   }
 }
 
@@ -360,8 +361,10 @@ message ShuffleReaderPartition {
 
 message GlobalLimitExecNode {
   PhysicalPlanNode input = 1;
+  // The number of rows to skip before fetch
   uint32 skip = 2;
-  uint32 fetch = 3;
+  // Maximum number of rows to fetch; negative means no limit
+  int64 fetch = 3;
 }
 
 message LocalLimitExecNode {
@@ -372,6 +375,13 @@ message LocalLimitExecNode {
 message SortExecNode {
   PhysicalPlanNode input = 1;
   repeated PhysicalExprNode expr = 2;
+  // Maximum number of highest/lowest rows to fetch; negative means no limit
+  int64 fetch = 3;
+}
+
+message SortPreservingMergeExecNode {
+  PhysicalPlanNode input = 1;
+  repeated PhysicalExprNode expr = 2;
 }
 
 message CoalesceBatchesExecNode {
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index bfcc4127..d2c6b089 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -46,6 +46,7 @@ use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
 use datafusion::physical_plan::projection::ProjectionExec;
 use datafusion::physical_plan::repartition::RepartitionExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
+use 
datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
 use datafusion::physical_plan::union::UnionExec;
 use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec};
 use datafusion::physical_plan::{
@@ -240,7 +241,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
             PhysicalPlanType::GlobalLimit(limit) => {
                 let input: Arc<dyn ExecutionPlan> =
                     into_physical_plan!(limit.input, registry, runtime, 
extension_codec)?;
-                let fetch = if limit.fetch > 0 {
+                let fetch = if limit.fetch >= 0 {
                     Some(limit.fetch as usize)
                 } else {
                     None
@@ -640,7 +641,53 @@ impl AsExecutionPlan for PhysicalPlanNode {
                         }
                     })
                     .collect::<Result<Vec<_>, _>>()?;
-                Ok(Arc::new(SortExec::try_new(exprs, input, None)?))
+                let fetch = if sort.fetch < 0 {
+                    None
+                } else {
+                    Some(sort.fetch as usize)
+                };
+                Ok(Arc::new(SortExec::try_new(exprs, input, fetch)?))
+            }
+            PhysicalPlanType::SortPreservingMerge(sort) => {
+                let input: Arc<dyn ExecutionPlan> =
+                    into_physical_plan!(sort.input, registry, runtime, 
extension_codec)?;
+                let exprs = sort
+                    .expr
+                    .iter()
+                    .map(|expr| {
+                        let expr = expr.expr_type.as_ref().ok_or_else(|| {
+                            proto_error(format!(
+                                "physical_plan::from_proto() Unexpected expr 
{:?}",
+                                self
+                            ))
+                        })?;
+                        if let 
protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr {
+                            let expr = sort_expr
+                                .expr
+                                .as_ref()
+                                .ok_or_else(|| {
+                                    proto_error(format!(
+                                        "physical_plan::from_proto() 
Unexpected sort expr {:?}",
+                                        self
+                                    ))
+                                })?
+                                .as_ref();
+                            Ok(PhysicalSortExpr {
+                                expr: parse_physical_expr(expr,registry, 
input.schema().as_ref())?,
+                                options: SortOptions {
+                                    descending: !sort_expr.asc,
+                                    nulls_first: sort_expr.nulls_first,
+                                },
+                            })
+                        } else {
+                            Err(BallistaError::General(format!(
+                                "physical_plan::from_proto() {:?}",
+                                self
+                            )))
+                        }
+                    })
+                    .collect::<Result<Vec<_>, _>>()?;
+                Ok(Arc::new(SortPreservingMergeExec::new(exprs, input)))
             }
             PhysicalPlanType::Unresolved(unresolved_shuffle) => {
                 let schema = 
Arc::new(convert_required!(unresolved_shuffle.schema)?);
@@ -739,7 +786,10 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     protobuf::GlobalLimitExecNode {
                         input: Some(Box::new(input)),
                         skip: limit.skip() as u32,
-                        fetch: *limit.fetch().unwrap_or(&0) as u32,
+                        fetch: match limit.fetch() {
+                            Some(n) => *n as i64,
+                            _ => -1, // no limit
+                        },
                     },
                 ))),
             })
@@ -1059,6 +1109,10 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     protobuf::SortExecNode {
                         input: Some(Box::new(input)),
                         expr,
+                        fetch: match exec.fetch() {
+                            Some(n) => n as i64,
+                            _ => -1,
+                        },
                     },
                 ))),
             })
@@ -1121,21 +1175,58 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     protobuf::UnionExecNode { inputs },
                 )),
             })
-        } else {
-            let mut buf: Vec<u8> = vec![];
-            extension_codec.try_encode(plan_clone.clone(), &mut buf)?;
-
-            let inputs: Vec<PhysicalPlanNode> = plan_clone
-                .children()
-                .into_iter()
-                .map(|i| PhysicalPlanNode::try_from_physical_plan(i, 
extension_codec))
-                .collect::<Result<_, BallistaError>>()?;
-
+        } else if let Some(exec) = 
plan.downcast_ref::<SortPreservingMergeExec>() {
+            let input = protobuf::PhysicalPlanNode::try_from_physical_plan(
+                exec.input().to_owned(),
+                extension_codec,
+            )?;
+            let expr = exec
+                .expr()
+                .iter()
+                .map(|expr| {
+                    let sort_expr = Box::new(protobuf::PhysicalSortExprNode {
+                        expr: Some(Box::new(expr.expr.to_owned().try_into()?)),
+                        asc: !expr.options.descending,
+                        nulls_first: expr.options.nulls_first,
+                    });
+                    Ok(protobuf::PhysicalExprNode {
+                        expr_type: 
Some(protobuf::physical_expr_node::ExprType::Sort(
+                            sort_expr,
+                        )),
+                    })
+                })
+                .collect::<Result<Vec<_>, BallistaError>>()?;
             Ok(protobuf::PhysicalPlanNode {
-                physical_plan_type: Some(PhysicalPlanType::Extension(
-                    PhysicalExtensionNode { node: buf, inputs },
+                physical_plan_type: Some(PhysicalPlanType::SortPreservingMerge(
+                    Box::new(protobuf::SortPreservingMergeExecNode {
+                        input: Some(Box::new(input)),
+                        expr,
+                    }),
                 )),
             })
+        } else {
+            let mut buf: Vec<u8> = vec![];
+            match extension_codec.try_encode(plan_clone.clone(), &mut buf) {
+                Ok(_) => {
+                    let inputs: Vec<PhysicalPlanNode> = plan_clone
+                        .children()
+                        .into_iter()
+                        .map(|i| {
+                            PhysicalPlanNode::try_from_physical_plan(i, 
extension_codec)
+                        })
+                        .collect::<Result<_, BallistaError>>()?;
+
+                    Ok(protobuf::PhysicalPlanNode {
+                        physical_plan_type: Some(PhysicalPlanType::Extension(
+                            PhysicalExtensionNode { node: buf, inputs },
+                        )),
+                    })
+                }
+                Err(e) => Err(BallistaError::Internal(format!(
+                    "Unsupported plan and extension codec failed with [{}]. 
Plan: {:?}",
+                    e, plan_clone
+                ))),
+            }
         }
     }
 }
@@ -1339,6 +1430,26 @@ mod roundtrip_tests {
         )))
     }
 
+    #[test]
+    fn roundtrip_global_skip_no_limit() -> Result<()> {
+        // skip = 10 with fetch = None: None is encoded as -1 and must decode back to None
+        roundtrip_test(Arc::new(GlobalLimitExec::new(
+            Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))),
+            10,
+            None, // no limit
+        )))
+    }
+
+    #[test]
+    fn roundtrip_global_limit_zero() -> Result<()> {
+        // LIMIT 0 (fetch = Some(0)) previously decoded as "no limit"; it must stay Some(0)
+        roundtrip_test(Arc::new(GlobalLimitExec::new(
+            Arc::new(EmptyExec::new(false, Arc::new(Schema::empty()))),
+            0,
+            Some(0),
+        )))
+    }
+
     #[test]
     fn roundtrip_hash_join() -> Result<()> {
         let field_a = Field::new("col", DataType::Int64, false);

Reply via email to