This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 727b6ff415 Add fetch to `SortPreservingMergeExec` and
`SortPreservingMergeStream` (#6811)
727b6ff415 is described below
commit 727b6ff41502e276ed7885531a87364a71826a74
Author: Daniël Heres <[email protected]>
AuthorDate: Mon Jul 3 12:29:32 2023 +0200
Add fetch to `SortPreservingMergeExec` and `SortPreservingMergeStream`
(#6811)
* Add fetch to sortpreservingmergeexec
* Add fetch to sortpreservingmergeexec
* fmt
* Deserialize
* Fmt
* Fix test
* Fix test
* Fix test
* Fix plan output
* Doc
* Update datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
Co-authored-by: Andrew Lamb <[email protected]>
* Extract into method
* Remove from sort enforcement
* Update datafusion/core/src/physical_plan/sorts/merge.rs
Co-authored-by: Mustafa Akur
<[email protected]>
* Update datafusion/proto/src/physical_plan/mod.rs
Co-authored-by: Mustafa Akur
<[email protected]>
---------
Co-authored-by: Daniël Heres <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
Co-authored-by: Mustafa Akur
<[email protected]>
---
.../physical_optimizer/global_sort_selection.rs | 2 +-
.../core/src/physical_plan/repartition/mod.rs | 1 +
datafusion/core/src/physical_plan/sorts/merge.rs | 38 ++++++++++++++++++----
datafusion/core/src/physical_plan/sorts/sort.rs | 4 +--
.../physical_plan/sorts/sort_preserving_merge.rs | 30 ++++++++++++++---
datafusion/core/tests/sql/explain_analyze.rs | 2 +-
.../sqllogictests/test_files/tpch/q10.slt.part | 2 +-
.../sqllogictests/test_files/tpch/q11.slt.part | 2 +-
.../sqllogictests/test_files/tpch/q13.slt.part | 2 +-
.../sqllogictests/test_files/tpch/q16.slt.part | 2 +-
.../sqllogictests/test_files/tpch/q2.slt.part | 2 +-
.../sqllogictests/test_files/tpch/q3.slt.part | 2 +-
.../sqllogictests/test_files/tpch/q9.slt.part | 2 +-
.../core/tests/sqllogictests/test_files/union.slt | 2 +-
.../core/tests/sqllogictests/test_files/window.slt | 2 +-
datafusion/proto/proto/datafusion.proto | 2 ++
datafusion/proto/src/generated/pbjson.rs | 19 +++++++++++
datafusion/proto/src/generated/prost.rs | 3 ++
datafusion/proto/src/physical_plan/mod.rs | 10 +++++-
19 files changed, 103 insertions(+), 26 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/global_sort_selection.rs
b/datafusion/core/src/physical_optimizer/global_sort_selection.rs
index 9466297d24..0b9054f89f 100644
--- a/datafusion/core/src/physical_optimizer/global_sort_selection.rs
+++ b/datafusion/core/src/physical_optimizer/global_sort_selection.rs
@@ -70,7 +70,7 @@ impl PhysicalOptimizerRule for GlobalSortSelection {
Arc::new(SortPreservingMergeExec::new(
sort_exec.expr().to_vec(),
Arc::new(sort),
- ));
+ ).with_fetch(sort_exec.fetch()));
Some(global_sort)
} else {
None
diff --git a/datafusion/core/src/physical_plan/repartition/mod.rs
b/datafusion/core/src/physical_plan/repartition/mod.rs
index 72ff0c3713..85225eb471 100644
--- a/datafusion/core/src/physical_plan/repartition/mod.rs
+++ b/datafusion/core/src/physical_plan/repartition/mod.rs
@@ -497,6 +497,7 @@ impl ExecutionPlan for RepartitionExec {
sort_exprs,
BaselineMetrics::new(&self.metrics, partition),
context.session_config().batch_size(),
+ None,
)
} else {
Ok(Box::pin(RepartitionStream {
diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs
b/datafusion/core/src/physical_plan/sorts/merge.rs
index d8a3cdef4d..e191c044b9 100644
--- a/datafusion/core/src/physical_plan/sorts/merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/merge.rs
@@ -39,13 +39,14 @@ macro_rules! primitive_merge_helper {
}
macro_rules! merge_helper {
- ($t:ty, $sort:ident, $streams:ident, $schema:ident,
$tracking_metrics:ident, $batch_size:ident) => {{
+ ($t:ty, $sort:ident, $streams:ident, $schema:ident,
$tracking_metrics:ident, $batch_size:ident, $fetch:ident) => {{
let streams = FieldCursorStream::<$t>::new($sort, $streams);
return Ok(Box::pin(SortPreservingMergeStream::new(
Box::new(streams),
$schema,
$tracking_metrics,
$batch_size,
+ $fetch,
)));
}};
}
@@ -57,17 +58,18 @@ pub(crate) fn streaming_merge(
expressions: &[PhysicalSortExpr],
metrics: BaselineMetrics,
batch_size: usize,
+ fetch: Option<usize>,
) -> Result<SendableRecordBatchStream> {
// Special case single column comparisons with optimized cursor
implementations
if expressions.len() == 1 {
let sort = expressions[0].clone();
let data_type = sort.expr.data_type(schema.as_ref())?;
downcast_primitive! {
- data_type => (primitive_merge_helper, sort, streams, schema,
metrics, batch_size),
- DataType::Utf8 => merge_helper!(StringArray, sort, streams,
schema, metrics, batch_size)
- DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort,
streams, schema, metrics, batch_size)
- DataType::Binary => merge_helper!(BinaryArray, sort, streams,
schema, metrics, batch_size)
- DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort,
streams, schema, metrics, batch_size)
+ data_type => (primitive_merge_helper, sort, streams, schema,
metrics, batch_size, fetch),
+ DataType::Utf8 => merge_helper!(StringArray, sort, streams,
schema, metrics, batch_size, fetch)
+ DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort,
streams, schema, metrics, batch_size, fetch)
+ DataType::Binary => merge_helper!(BinaryArray, sort, streams,
schema, metrics, batch_size, fetch)
+ DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort,
streams, schema, metrics, batch_size, fetch)
_ => {}
}
}
@@ -78,6 +80,7 @@ pub(crate) fn streaming_merge(
schema,
metrics,
batch_size,
+ fetch,
)))
}
@@ -140,6 +143,12 @@ struct SortPreservingMergeStream<C> {
/// Vector that holds cursors for each non-exhausted input partition
cursors: Vec<Option<C>>,
+
+ /// Optional number of rows to fetch
+ fetch: Option<usize>,
+
+ /// Number of rows produced so far
+ produced: usize,
}
impl<C: Cursor> SortPreservingMergeStream<C> {
@@ -148,6 +157,7 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
schema: SchemaRef,
metrics: BaselineMetrics,
batch_size: usize,
+ fetch: Option<usize>,
) -> Self {
let stream_count = streams.partitions();
@@ -160,6 +170,8 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
loser_tree: vec![],
loser_tree_adjusted: false,
batch_size,
+ fetch,
+ produced: 0,
}
}
@@ -227,15 +239,27 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
if self.advance(stream_idx) {
self.loser_tree_adjusted = false;
self.in_progress.push_row(stream_idx);
- if self.in_progress.len() < self.batch_size {
+
+ // stop merging once the fetch limit has been reached
+ if self.fetch_reached() {
+ self.aborted = true;
+ } else if self.in_progress.len() < self.batch_size {
continue;
}
}
+ self.produced += self.in_progress.len();
+
return
Poll::Ready(self.in_progress.build_record_batch().transpose());
}
}
+ fn fetch_reached(&mut self) -> bool {
+ self.fetch
+ .map(|fetch| self.produced + self.in_progress.len() >= fetch)
+ .unwrap_or(false)
+ }
+
fn advance(&mut self, stream_idx: usize) -> bool {
let slot = &mut self.cursors[stream_idx];
match slot.as_mut() {
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs
b/datafusion/core/src/physical_plan/sorts/sort.rs
index 4983b0ea83..205ec706b5 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -189,6 +189,7 @@ impl ExternalSorter {
&self.expr,
self.metrics.baseline.clone(),
self.batch_size,
+ self.fetch,
)
} else if !self.in_mem_batches.is_empty() {
let result =
self.in_mem_sort_stream(self.metrics.baseline.clone());
@@ -285,14 +286,13 @@ impl ExternalSorter {
})
.collect::<Result<_>>()?;
- // TODO: Pushdown fetch to streaming merge (#6000)
-
streaming_merge(
streams,
self.schema.clone(),
&self.expr,
metrics,
self.batch_size,
+ self.fetch,
)
}
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index 4db1fea2a4..397d254162 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -71,6 +71,8 @@ pub struct SortPreservingMergeExec {
expr: Vec<PhysicalSortExpr>,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
+ /// Optional number of rows to fetch. Stops producing rows once this many have been emitted
+ fetch: Option<usize>,
}
impl SortPreservingMergeExec {
@@ -80,8 +82,14 @@ impl SortPreservingMergeExec {
input,
expr,
metrics: ExecutionPlanMetricsSet::new(),
+ fetch: None,
}
}
+ /// Sets the number of rows to fetch
+ pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
+ self.fetch = fetch;
+ self
+ }
/// Input schema
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
@@ -92,6 +100,11 @@ impl SortPreservingMergeExec {
pub fn expr(&self) -> &[PhysicalSortExpr] {
&self.expr
}
+
+ /// Optional maximum number of rows to fetch (`None` means no limit)
+ pub fn fetch(&self) -> Option<usize> {
+ self.fetch
+ }
}
impl ExecutionPlan for SortPreservingMergeExec {
@@ -137,10 +150,10 @@ impl ExecutionPlan for SortPreservingMergeExec {
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
- Ok(Arc::new(SortPreservingMergeExec::new(
- self.expr.clone(),
- children[0].clone(),
- )))
+ Ok(Arc::new(
+ SortPreservingMergeExec::new(self.expr.clone(),
children[0].clone())
+ .with_fetch(self.fetch),
+ ))
}
fn execute(
@@ -192,6 +205,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
&self.expr,
BaselineMetrics::new(&self.metrics, partition),
context.session_config().batch_size(),
+ self.fetch,
)?;
debug!("Got stream result from
SortPreservingMergeStream::new_from_receivers");
@@ -209,7 +223,12 @@ impl ExecutionPlan for SortPreservingMergeExec {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
let expr: Vec<String> = self.expr.iter().map(|e|
e.to_string()).collect();
- write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))
+ write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))?;
+ if let Some(fetch) = self.fetch {
+ write!(f, ", fetch={fetch}")?;
+ };
+
+ Ok(())
}
}
}
@@ -814,6 +833,7 @@ mod tests {
sort.as_slice(),
BaselineMetrics::new(&metrics, 0),
task_ctx.session_config().batch_size(),
+ None,
)
.unwrap();
diff --git a/datafusion/core/tests/sql/explain_analyze.rs
b/datafusion/core/tests/sql/explain_analyze.rs
index 01bdb629ee..e0130cb09c 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -599,7 +599,7 @@ async fn test_physical_plan_display_indent() {
let physical_plan = dataframe.create_physical_plan().await.unwrap();
let expected = vec![
"GlobalLimitExec: skip=0, fetch=10",
- " SortPreservingMergeExec: [the_min@2 DESC]",
+ " SortPreservingMergeExec: [the_min@2 DESC], fetch=10",
" SortExec: fetch=10, expr=[the_min@2 DESC]",
" ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1
as MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)@2 as the_min]",
" AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1],
aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]",
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part
b/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part
index d2e06d5ff6..6c662c1091 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q10.slt.part
@@ -71,7 +71,7 @@ Limit: skip=0, fetch=10
------------TableScan: nation projection=[n_nationkey, n_name]
physical_plan
GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [revenue@2 DESC]
+--SortPreservingMergeExec: [revenue@2 DESC], fetch=10
----SortExec: fetch=10, expr=[revenue@2 DESC]
------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name,
SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@7 as revenue,
c_acctbal@2 as c_acctbal, n_name@4 as n_name, c_address@5 as c_address,
c_phone@3 as c_phone, c_comment@6 as c_comment]
--------AggregateExec: mode=FinalPartitioned, gby=[c_custkey@0 as c_custkey,
c_name@1 as c_name, c_acctbal@2 as c_acctbal, c_phone@3 as c_phone, n_name@4 as
n_name, c_address@5 as c_address, c_comment@6 as c_comment],
aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part
b/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part
index af29708c67..0c16fe1ab9 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q11.slt.part
@@ -75,7 +75,7 @@ Limit: skip=0, fetch=10
----------------------TableScan: nation projection=[n_nationkey, n_name],
partial_filters=[nation.n_name = Utf8("GERMANY")]
physical_plan
GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [value@1 DESC]
+--SortPreservingMergeExec: [value@1 DESC], fetch=10
----SortExec: fetch=10, expr=[value@1 DESC]
------ProjectionExec: expr=[ps_partkey@0 as ps_partkey,
SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value]
--------NestedLoopJoinExec: join_type=Inner,
filter=CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS
Decimal128(38, 15)) > SUM(partsupp.ps_supplycost * partsupp.ps_availqty) *
Float64(0.0001)@1
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part
b/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part
index 7e5be14271..8ac9576a12 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q13.slt.part
@@ -56,7 +56,7 @@ Limit: skip=0, fetch=10
------------------------TableScan: orders projection=[o_orderkey, o_custkey,
o_comment], partial_filters=[orders.o_comment NOT LIKE
Utf8("%special%requests%")]
physical_plan
GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC]
+--SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC], fetch=10
----SortExec: fetch=10, expr=[custdist@1 DESC,c_count@0 DESC]
------ProjectionExec: expr=[c_count@0 as c_count, COUNT(UInt8(1))@1 as
custdist]
--------AggregateExec: mode=FinalPartitioned, gby=[c_count@0 as c_count],
aggr=[COUNT(UInt8(1))]
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part
b/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part
index 677db0329c..58796e93a8 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q16.slt.part
@@ -67,7 +67,7 @@ Limit: skip=0, fetch=10
------------------TableScan: supplier projection=[s_suppkey, s_comment],
partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")]
physical_plan
GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS
LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST]
+--SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS
LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10
----SortExec: fetch=10, expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS
LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST]
------ProjectionExec: expr=[group_alias_0@0 as part.p_brand, group_alias_1@1
as part.p_type, group_alias_2@2 as part.p_size, COUNT(alias1)@3 as supplier_cnt]
--------AggregateExec: mode=FinalPartitioned, gby=[group_alias_0@0 as
group_alias_0, group_alias_1@1 as group_alias_1, group_alias_2@2 as
group_alias_2], aggr=[COUNT(alias1)]
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part
b/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part
index 4ad1ed7293..18cd261b76 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q2.slt.part
@@ -101,7 +101,7 @@ Limit: skip=0, fetch=10
----------------------TableScan: region projection=[r_regionkey, r_name],
partial_filters=[region.r_name = Utf8("EUROPE")]
physical_plan
GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1
ASC NULLS LAST,p_partkey@3 ASC NULLS LAST]
+--SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1
ASC NULLS LAST,p_partkey@3 ASC NULLS LAST], fetch=10
----SortExec: fetch=10, expr=[s_acctbal@0 DESC,n_name@2 ASC NULLS
LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST]
------ProjectionExec: expr=[s_acctbal@5 as s_acctbal, s_name@2 as s_name,
n_name@8 as n_name, p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_address@3
as s_address, s_phone@4 as s_phone, s_comment@6 as s_comment]
--------CoalesceBatchesExec: target_batch_size=8192
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part
b/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part
index dc3b150877..f8c1385681 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q3.slt.part
@@ -60,7 +60,7 @@ Limit: skip=0, fetch=10
----------------TableScan: lineitem projection=[l_orderkey, l_extendedprice,
l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate > Date32("9204")]
physical_plan
GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST]
+--SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST],
fetch=10
----SortExec: fetch=10, expr=[revenue@1 DESC,o_orderdate@2 ASC NULLS LAST]
------ProjectionExec: expr=[l_orderkey@0 as l_orderkey,
SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@3 as revenue,
o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority]
--------AggregateExec: mode=FinalPartitioned, gby=[l_orderkey@0 as l_orderkey,
o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority],
aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
diff --git a/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part
b/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part
index 756b2e2c7c..45a4be6466 100644
--- a/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part
+++ b/datafusion/core/tests/sqllogictests/test_files/tpch/q9.slt.part
@@ -77,7 +77,7 @@ Limit: skip=0, fetch=10
--------------TableScan: nation projection=[n_nationkey, n_name]
physical_plan
GlobalLimitExec: skip=0, fetch=10
---SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC]
+--SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC], fetch=10
----SortExec: fetch=10, expr=[nation@0 ASC NULLS LAST,o_year@1 DESC]
------ProjectionExec: expr=[nation@0 as nation, o_year@1 as o_year,
SUM(profit.amount)@2 as sum_profit]
--------AggregateExec: mode=FinalPartitioned, gby=[nation@0 as nation,
o_year@1 as o_year], aggr=[SUM(profit.amount)]
diff --git a/datafusion/core/tests/sqllogictests/test_files/union.slt
b/datafusion/core/tests/sqllogictests/test_files/union.slt
index 94c9eef893..2b3022ddd1 100644
--- a/datafusion/core/tests/sqllogictests/test_files/union.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/union.slt
@@ -308,7 +308,7 @@ Limit: skip=0, fetch=5
--------TableScan: aggregate_test_100 projection=[c1, c3]
physical_plan
GlobalLimitExec: skip=0, fetch=5
---SortPreservingMergeExec: [c9@1 DESC]
+--SortPreservingMergeExec: [c9@1 DESC], fetch=5
----UnionExec
------SortExec: expr=[c9@1 DESC]
--------ProjectionExec: expr=[c1@0 as c1, CAST(c9@1 AS Int64) as c9]
diff --git a/datafusion/core/tests/sqllogictests/test_files/window.slt
b/datafusion/core/tests/sqllogictests/test_files/window.slt
index 08d1a5616e..d77df127a8 100644
--- a/datafusion/core/tests/sqllogictests/test_files/window.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/window.slt
@@ -1792,7 +1792,7 @@ Limit: skip=0, fetch=5
------------TableScan: aggregate_test_100 projection=[c2, c3, c9]
physical_plan
GlobalLimitExec: skip=0, fetch=5
---SortPreservingMergeExec: [c3@0 ASC NULLS LAST]
+--SortPreservingMergeExec: [c3@0 ASC NULLS LAST], fetch=5
----ProjectionExec: expr=[c3@0 as c3, SUM(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS
FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING
AND CURRENT ROW@2 as sum1, SUM(aggregate_test_100.c9) PARTITION BY
[aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE
BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum2]
------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name:
"SUM(aggregate_test_100.c9)", data_type: UInt64, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range,
start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow }], mode=[Sorted]
--------SortExec: expr=[c3@0 ASC NULLS LAST,c9@1 DESC]
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index de334dc4a5..0d61cd2b35 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1366,6 +1366,8 @@ message SortExecNode {
message SortPreservingMergeExecNode {
PhysicalPlanNode input = 1;
repeated PhysicalExprNode expr = 2;
+ // Maximum number of highest/lowest rows to fetch; negative means no limit
+ int64 fetch = 3;
}
message CoalesceBatchesExecNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 1cf08be321..831dd49618 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -20269,6 +20269,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
if !self.expr.is_empty() {
len += 1;
}
+ if self.fetch != 0 {
+ len += 1;
+ }
let mut struct_ser =
serializer.serialize_struct("datafusion.SortPreservingMergeExecNode", len)?;
if let Some(v) = self.input.as_ref() {
struct_ser.serialize_field("input", v)?;
@@ -20276,6 +20279,9 @@ impl serde::Serialize for SortPreservingMergeExecNode {
if !self.expr.is_empty() {
struct_ser.serialize_field("expr", &self.expr)?;
}
+ if self.fetch != 0 {
+ struct_ser.serialize_field("fetch",
ToString::to_string(&self.fetch).as_str())?;
+ }
struct_ser.end()
}
}
@@ -20288,12 +20294,14 @@ impl<'de> serde::Deserialize<'de> for
SortPreservingMergeExecNode {
const FIELDS: &[&str] = &[
"input",
"expr",
+ "fetch",
];
#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Input,
Expr,
+ Fetch,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -20317,6 +20325,7 @@ impl<'de> serde::Deserialize<'de> for
SortPreservingMergeExecNode {
match value {
"input" => Ok(GeneratedField::Input),
"expr" => Ok(GeneratedField::Expr),
+ "fetch" => Ok(GeneratedField::Fetch),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -20338,6 +20347,7 @@ impl<'de> serde::Deserialize<'de> for
SortPreservingMergeExecNode {
{
let mut input__ = None;
let mut expr__ = None;
+ let mut fetch__ = None;
while let Some(k) = map.next_key()? {
match k {
GeneratedField::Input => {
@@ -20352,11 +20362,20 @@ impl<'de> serde::Deserialize<'de> for
SortPreservingMergeExecNode {
}
expr__ = Some(map.next_value()?);
}
+ GeneratedField::Fetch => {
+ if fetch__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("fetch"));
+ }
+ fetch__ =
+
Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0)
+ ;
+ }
}
}
Ok(SortPreservingMergeExecNode {
input: input__,
expr: expr__.unwrap_or_default(),
+ fetch: fetch__.unwrap_or_default(),
})
}
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 5f201b124d..e6c076e7d4 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1926,6 +1926,9 @@ pub struct SortPreservingMergeExecNode {
pub input:
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
#[prost(message, repeated, tag = "2")]
pub expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
+ /// Maximum number of highest/lowest rows to fetch; negative means no limit
+ #[prost(int64, tag = "3")]
+ pub fetch: i64,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index 1daa1c2e4b..7bbbe13568 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -692,7 +692,14 @@ impl AsExecutionPlan for PhysicalPlanNode {
}
})
.collect::<Result<Vec<_>, _>>()?;
- Ok(Arc::new(SortPreservingMergeExec::new(exprs, input)))
+ let fetch = if sort.fetch < 0 {
+ None
+ } else {
+ Some(sort.fetch as usize)
+ };
+ Ok(Arc::new(
+ SortPreservingMergeExec::new(exprs,
input).with_fetch(fetch),
+ ))
}
PhysicalPlanType::Extension(extension) => {
let inputs: Vec<Arc<dyn ExecutionPlan>> = extension
@@ -1144,6 +1151,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
Box::new(protobuf::SortPreservingMergeExecNode {
input: Some(Box::new(input)),
expr,
+ fetch: exec.fetch().map(|f| f as i64).unwrap_or(-1),
}),
)),
})