This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 05d5f01 implement window functions with partition by (#558)
05d5f01 is described below
commit 05d5f01fa8ec7bf9baa3aa632ccedb914d0b49a2
Author: Jiayu Liu <[email protected]>
AuthorDate: Mon Jun 21 18:57:04 2021 +0800
implement window functions with partition by (#558)
---
datafusion/src/execution/context.rs | 74 ++++++++++++++++++++++
.../src/physical_plan/expressions/nth_value.rs | 10 ++-
datafusion/src/physical_plan/mod.rs | 36 +++++++----
datafusion/src/physical_plan/planner.rs | 6 --
datafusion/src/physical_plan/windows.rs | 61 ++++++++++++++++--
datafusion/tests/sql.rs | 64 +++++++++++++++++++
.../sqls/simple_window_partition_aggregation.sql | 26 ++++++++
.../simple_window_partition_order_aggregation.sql | 26 ++++++++
integration-tests/test_psql_parity.py | 2 +-
9 files changed, 275 insertions(+), 30 deletions(-)
diff --git a/datafusion/src/execution/context.rs
b/datafusion/src/execution/context.rs
index ef652c2..b42695b 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1356,6 +1356,80 @@ mod tests {
}
#[tokio::test]
+ async fn window_partition_by() -> Result<()> {
+ let results = execute(
+ "SELECT \
+ c1, \
+ c2, \
+ SUM(c2) OVER (PARTITION BY c2), \
+ COUNT(c2) OVER (PARTITION BY c2), \
+ MAX(c2) OVER (PARTITION BY c2), \
+ MIN(c2) OVER (PARTITION BY c2), \
+ AVG(c2) OVER (PARTITION BY c2) \
+ FROM test \
+ ORDER BY c1, c2 \
+ LIMIT 5",
+ 4,
+ )
+ .await?;
+
+ let expected = vec![
+ "+----+----+---------+-----------+---------+---------+---------+",
+ "| c1 | c2 | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
+ "+----+----+---------+-----------+---------+---------+---------+",
+ "| 0 | 1 | 4 | 4 | 1 | 1 | 1 |",
+ "| 0 | 2 | 8 | 4 | 2 | 2 | 2 |",
+ "| 0 | 3 | 12 | 4 | 3 | 3 | 3 |",
+ "| 0 | 4 | 16 | 4 | 4 | 4 | 4 |",
+ "| 0 | 5 | 20 | 4 | 5 | 5 | 5 |",
+ "+----+----+---------+-----------+---------+---------+---------+",
+ ];
+
+ // window function shall respect ordering
+ assert_batches_eq!(expected, &results);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn window_partition_by_order_by() -> Result<()> {
+ let results = execute(
+ "SELECT \
+ c1, \
+ c2, \
+ ROW_NUMBER() OVER (PARTITION BY c2 ORDER BY c1), \
+ FIRST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \
+ LAST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \
+ NTH_VALUE(c2 + c1, 2) OVER (PARTITION BY c2 ORDER BY c1), \
+ SUM(c2) OVER (PARTITION BY c2 ORDER BY c1), \
+ COUNT(c2) OVER (PARTITION BY c2 ORDER BY c1), \
+ MAX(c2) OVER (PARTITION BY c2 ORDER BY c1), \
+ MIN(c2) OVER (PARTITION BY c2 ORDER BY c1), \
+ AVG(c2) OVER (PARTITION BY c2 ORDER BY c1) \
+ FROM test \
+ ORDER BY c1, c2 \
+ LIMIT 5",
+ 4,
+ )
+ .await?;
+
+ let expected = vec![
+
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
+ "| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2 Plus c1) |
LAST_VALUE(c2 Plus c1) | NTH_VALUE(c2 Plus c1,Int64(2)) | SUM(c2) | COUNT(c2) |
MAX(c2) | MIN(c2) | AVG(c2) |",
+
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
+ "| 0 | 1 | 1 | 1 | 4
| 2 | 1 | 1 | 1 | 1
| 1 |",
+ "| 0 | 2 | 1 | 2 | 5
| 3 | 2 | 1 | 2 | 2
| 2 |",
+ "| 0 | 3 | 1 | 3 | 6
| 4 | 3 | 1 | 3 | 3
| 3 |",
+ "| 0 | 4 | 1 | 4 | 7
| 5 | 4 | 1 | 4 | 4
| 4 |",
+ "| 0 | 5 | 1 | 5 | 8
| 6 | 5 | 1 | 5 | 5
| 5 |",
+
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
+ ];
+
+ // window function shall respect ordering
+ assert_batches_eq!(expected, &results);
+ Ok(())
+ }
+
+ #[tokio::test]
async fn aggregate() -> Result<()> {
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;
assert_eq!(results.len(), 1);
diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs
b/datafusion/src/physical_plan/expressions/nth_value.rs
index 98083fa..16897d4 100644
--- a/datafusion/src/physical_plan/expressions/nth_value.rs
+++ b/datafusion/src/physical_plan/expressions/nth_value.rs
@@ -20,7 +20,7 @@
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr,
PhysicalExpr};
use crate::scalar::ScalarValue;
-use arrow::array::{new_empty_array, ArrayRef};
+use arrow::array::{new_empty_array, new_null_array, ArrayRef};
use arrow::datatypes::{DataType, Field};
use std::any::Any;
use std::sync::Arc;
@@ -135,8 +135,12 @@ impl BuiltInWindowFunctionExpr for NthValue {
NthValueKind::Last => (num_rows as usize) - 1,
NthValueKind::Nth(n) => (n as usize) - 1,
};
- let value = ScalarValue::try_from_array(value, index)?;
- Ok(value.to_array_of_size(num_rows))
+ Ok(if index >= num_rows {
+ new_null_array(value.data_type(), num_rows)
+ } else {
+ let value = ScalarValue::try_from_array(value, index)?;
+ value.to_array_of_size(num_rows)
+ })
}
}
diff --git a/datafusion/src/physical_plan/mod.rs
b/datafusion/src/physical_plan/mod.rs
index 713956f..50c30a5 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -485,19 +485,20 @@ pub trait WindowExpr: Send + Sync + Debug {
/// evaluate the window function values against the batch
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
- /// evaluate the sort partition points
- fn evaluate_sort_partition_points(
+ /// evaluate the partition points given the sort columns; if the sort
columns are
+ /// empty then the result will be a single element vec of the whole column
rows.
+ fn evaluate_partition_points(
&self,
- batch: &RecordBatch,
+ num_rows: usize,
+ partition_columns: &[SortColumn],
) -> Result<Vec<Range<usize>>> {
- let sort_columns = self.sort_columns(batch)?;
- if sort_columns.is_empty() {
+ if partition_columns.is_empty() {
Ok(vec![Range {
start: 0,
- end: batch.num_rows(),
+ end: num_rows,
}])
} else {
- lexicographical_partition_ranges(&sort_columns)
+ lexicographical_partition_ranges(partition_columns)
.map_err(DataFusionError::ArrowError)
}
}
@@ -508,8 +509,8 @@ pub trait WindowExpr: Send + Sync + Debug {
/// expressions that's from the window function's order by clause, empty
if absent
fn order_by(&self) -> &[PhysicalSortExpr];
- /// get sort columns that can be used for partitioning, empty if absent
- fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
+ /// get partition columns that can be used for partitioning, empty if
absent
+ fn partition_columns(&self, batch: &RecordBatch) ->
Result<Vec<SortColumn>> {
self.partition_by()
.iter()
.map(|expr| {
@@ -519,13 +520,20 @@ pub trait WindowExpr: Send + Sync + Debug {
}
.evaluate_to_sort_column(batch)
})
- .chain(
- self.order_by()
- .iter()
- .map(|e| e.evaluate_to_sort_column(batch)),
- )
.collect()
}
+
+ /// get sort columns that can be used for peer evaluation, empty if absent
+ fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
+ let mut sort_columns = self.partition_columns(batch)?;
+ let order_by_columns = self
+ .order_by()
+ .iter()
+ .map(|e| e.evaluate_to_sort_column(batch))
+ .collect::<Result<Vec<SortColumn>>>()?;
+ sort_columns.extend(order_by_columns);
+ Ok(sort_columns)
+ }
}
/// An accumulator represents a stateful object that lives throughout the
evaluation of multiple rows and
diff --git a/datafusion/src/physical_plan/planner.rs
b/datafusion/src/physical_plan/planner.rs
index 1121c28..af0e60f 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -775,12 +775,6 @@ impl DefaultPhysicalPlanner {
)),
})
.collect::<Result<Vec<_>>>()?;
- if !partition_by.is_empty() {
- return Err(DataFusionError::NotImplemented(
- "window expression with non-empty partition by
clause is not yet supported"
- .to_owned(),
- ));
- }
if window_frame.is_some() {
return Err(DataFusionError::NotImplemented(
"window expression with window frame definition is
not yet supported"
diff --git a/datafusion/src/physical_plan/windows.rs
b/datafusion/src/physical_plan/windows.rs
index e557097..466cc51 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -175,10 +175,45 @@ impl WindowExpr for BuiltInWindowExpr {
// case when partition_by is supported, in which case we'll
parallelize the calls.
// See https://github.com/apache/arrow-datafusion/issues/299
let values = self.evaluate_args(batch)?;
- self.window.evaluate(batch.num_rows(), &values)
+ let partition_points = self.evaluate_partition_points(
+ batch.num_rows(),
+ &self.partition_columns(batch)?,
+ )?;
+ let results = partition_points
+ .iter()
+ .map(|partition_range| {
+ let start = partition_range.start;
+ let len = partition_range.end - start;
+ let values = values
+ .iter()
+ .map(|arr| arr.slice(start, len))
+ .collect::<Vec<_>>();
+ self.window.evaluate(len, &values)
+ })
+ .collect::<Result<Vec<_>>>()?
+ .into_iter()
+ .collect::<Vec<ArrayRef>>();
+ let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
+ concat(&results).map_err(DataFusionError::ArrowError)
}
}
+/// Given a partition range, and the full list of sort partition points, given
that the sort
+/// partition points are sorted using [partition columns..., order
columns...], the split
+/// boundaries would align (what's sorted on [partition columns...] would
definitely be sorted
+/// on finer columns), so this will use binary search to find ranges that are
within the
+/// partition range and return the valid slice.
+fn find_ranges_in_range<'a>(
+ partition_range: &Range<usize>,
+ sort_partition_points: &'a [Range<usize>],
+) -> &'a [Range<usize>] {
+ let start_idx = sort_partition_points
+ .partition_point(|sort_range| sort_range.start <
partition_range.start);
+ let end_idx = sort_partition_points
+ .partition_point(|sort_range| sort_range.end <= partition_range.end);
+ &sort_partition_points[start_idx..end_idx]
+}
+
/// A window expr that takes the form of an aggregate function
#[derive(Debug)]
pub struct AggregateWindowExpr {
@@ -205,13 +240,27 @@ impl AggregateWindowExpr {
/// and then per partition point we'll evaluate the peer group (e.g. SUM
or MAX gives the same
/// results for peers) and concatenate the results.
fn peer_based_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- let sort_partition_points =
self.evaluate_sort_partition_points(batch)?;
- let mut window_accumulators = self.create_accumulator()?;
+ let num_rows = batch.num_rows();
+ let partition_points =
+ self.evaluate_partition_points(num_rows,
&self.partition_columns(batch)?)?;
+ let sort_partition_points =
+ self.evaluate_partition_points(num_rows,
&self.sort_columns(batch)?)?;
let values = self.evaluate_args(batch)?;
- let results = sort_partition_points
+ let results = partition_points
.iter()
- .map(|peer_range| window_accumulators.scan_peers(&values,
peer_range))
- .collect::<Result<Vec<_>>>()?;
+ .map(|partition_range| {
+ let sort_partition_points =
+ find_ranges_in_range(partition_range,
&sort_partition_points);
+ let mut window_accumulators = self.create_accumulator()?;
+ sort_partition_points
+ .iter()
+ .map(|range| window_accumulators.scan_peers(&values,
range))
+ .collect::<Result<Vec<_>>>()
+ })
+ .collect::<Result<Vec<Vec<ArrayRef>>>>()?
+ .into_iter()
+ .flatten()
+ .collect::<Vec<ArrayRef>>();
let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
concat(&results).map_err(DataFusionError::ArrowError)
}
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index b6393e9..cfdb6f4 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -869,6 +869,70 @@ async fn csv_query_window_with_empty_over() -> Result<()> {
}
#[tokio::test]
+async fn csv_query_window_with_partition_by() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx)?;
+ let sql = "select \
+ c9, \
+ sum(cast(c4 as Int)) over (partition by c3), \
+ avg(cast(c4 as Int)) over (partition by c3), \
+ count(cast(c4 as Int)) over (partition by c3), \
+ max(cast(c4 as Int)) over (partition by c3), \
+ min(cast(c4 as Int)) over (partition by c3), \
+ first_value(cast(c4 as Int)) over (partition by c3), \
+ last_value(cast(c4 as Int)) over (partition by c3), \
+ nth_value(cast(c4 as Int), 2) over (partition by c3) \
+ from aggregate_test_100 \
+ order by c9 \
+ limit 5";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![
+ vec![
+ "28774375", "-16110", "-16110", "1", "-16110", "-16110", "-16110",
"-16110",
+ "NULL",
+ ],
+ vec![
+ "63044568", "3917", "3917", "1", "3917", "3917", "3917", "3917",
"NULL",
+ ],
+ vec![
+ "141047417",
+ "-38455",
+ "-19227.5",
+ "2",
+ "-16974",
+ "-21481",
+ "-16974",
+ "-21481",
+ "-21481",
+ ],
+ vec![
+ "141680161",
+ "-1114",
+ "-1114",
+ "1",
+ "-1114",
+ "-1114",
+ "-1114",
+ "-1114",
+ "NULL",
+ ],
+ vec![
+ "145294611",
+ "15673",
+ "15673",
+ "1",
+ "15673",
+ "15673",
+ "15673",
+ "15673",
+ "NULL",
+ ],
+ ];
+ assert_eq!(expected, actual);
+ Ok(())
+}
+
+#[tokio::test]
async fn csv_query_window_with_order_by() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx)?;
diff --git a/integration-tests/sqls/simple_window_partition_aggregation.sql
b/integration-tests/sqls/simple_window_partition_aggregation.sql
new file mode 100644
index 0000000..f395671
--- /dev/null
+++ b/integration-tests/sqls/simple_window_partition_aggregation.sql
@@ -0,0 +1,26 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+
+-- http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+SELECT
+ c9,
+ row_number() OVER (PARTITION BY c2, c9) AS row_number,
+ count(c3) OVER (PARTITION BY c2) AS count_c3,
+ avg(c3) OVER (PARTITION BY c2) AS avg_c3_by_c2,
+ sum(c3) OVER (PARTITION BY c2) AS sum_c3_by_c2,
+ max(c3) OVER (PARTITION BY c2) AS max_c3_by_c2,
+ min(c3) OVER (PARTITION BY c2) AS min_c3_by_c2
+FROM test
+ORDER BY c9;
diff --git
a/integration-tests/sqls/simple_window_partition_order_aggregation.sql
b/integration-tests/sqls/simple_window_partition_order_aggregation.sql
new file mode 100644
index 0000000..a11a9ec
--- /dev/null
+++ b/integration-tests/sqls/simple_window_partition_order_aggregation.sql
@@ -0,0 +1,26 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+
+-- http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+SELECT
+ c9,
+ row_number() OVER (PARTITION BY c2 ORDER BY c9) AS row_number,
+ count(c3) OVER (PARTITION BY c2 ORDER BY c9) AS count_c3,
+ avg(c3) OVER (PARTITION BY c2 ORDER BY c9) AS avg_c3_by_c2,
+ sum(c3) OVER (PARTITION BY c2 ORDER BY c9) AS sum_c3_by_c2,
+ max(c3) OVER (PARTITION BY c2 ORDER BY c9) AS max_c3_by_c2,
+ min(c3) OVER (PARTITION BY c2 ORDER BY c9) AS min_c3_by_c2
+FROM test
+ORDER BY c9;
diff --git a/integration-tests/test_psql_parity.py
b/integration-tests/test_psql_parity.py
index 4e0878c..c4b5a75 100644
--- a/integration-tests/test_psql_parity.py
+++ b/integration-tests/test_psql_parity.py
@@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase):
def test_parity(self):
root = Path(os.path.dirname(__file__)) / "sqls"
files = set(root.glob("*.sql"))
- self.assertEqual(len(files), 7, msg="tests are missed")
+ self.assertEqual(len(files), 9, msg="tests are missed")
for fname in files:
with self.subTest(fname=fname):
datafusion_output = pd.read_csv(