This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 05d5f01  implement window functions with partition by (#558)
05d5f01 is described below

commit 05d5f01fa8ec7bf9baa3aa632ccedb914d0b49a2
Author: Jiayu Liu <[email protected]>
AuthorDate: Mon Jun 21 18:57:04 2021 +0800

    implement window functions with partition by (#558)
---
 datafusion/src/execution/context.rs                | 74 ++++++++++++++++++++++
 .../src/physical_plan/expressions/nth_value.rs     | 10 ++-
 datafusion/src/physical_plan/mod.rs                | 36 +++++++----
 datafusion/src/physical_plan/planner.rs            |  6 --
 datafusion/src/physical_plan/windows.rs            | 61 ++++++++++++++++--
 datafusion/tests/sql.rs                            | 64 +++++++++++++++++++
 .../sqls/simple_window_partition_aggregation.sql   | 26 ++++++++
 .../simple_window_partition_order_aggregation.sql  | 26 ++++++++
 integration-tests/test_psql_parity.py              |  2 +-
 9 files changed, 275 insertions(+), 30 deletions(-)

diff --git a/datafusion/src/execution/context.rs 
b/datafusion/src/execution/context.rs
index ef652c2..b42695b 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1356,6 +1356,80 @@ mod tests {
     }
 
     #[tokio::test]
+    async fn window_partition_by() -> Result<()> {
+        let results = execute(
+            "SELECT \
+            c1, \
+            c2, \
+            SUM(c2) OVER (PARTITION BY c2), \
+            COUNT(c2) OVER (PARTITION BY c2), \
+            MAX(c2) OVER (PARTITION BY c2), \
+            MIN(c2) OVER (PARTITION BY c2), \
+            AVG(c2) OVER (PARTITION BY c2) \
+            FROM test \
+            ORDER BY c1, c2 \
+            LIMIT 5",
+            4,
+        )
+        .await?;
+
+        let expected = vec![
+            "+----+----+---------+-----------+---------+---------+---------+",
+            "| c1 | c2 | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
+            "+----+----+---------+-----------+---------+---------+---------+",
+            "| 0  | 1  | 4       | 4         | 1       | 1       | 1       |",
+            "| 0  | 2  | 8       | 4         | 2       | 2       | 2       |",
+            "| 0  | 3  | 12      | 4         | 3       | 3       | 3       |",
+            "| 0  | 4  | 16      | 4         | 4       | 4       | 4       |",
+            "| 0  | 5  | 20      | 4         | 5       | 5       | 5       |",
+            "+----+----+---------+-----------+---------+---------+---------+",
+        ];
+
+        // window function shall respect ordering
+        assert_batches_eq!(expected, &results);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn window_partition_by_order_by() -> Result<()> {
+        let results = execute(
+            "SELECT \
+            c1, \
+            c2, \
+            ROW_NUMBER() OVER (PARTITION BY c2 ORDER BY c1), \
+            FIRST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \
+            LAST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \
+            NTH_VALUE(c2 + c1, 2) OVER (PARTITION BY c2 ORDER BY c1), \
+            SUM(c2) OVER (PARTITION BY c2 ORDER BY c1), \
+            COUNT(c2) OVER (PARTITION BY c2 ORDER BY c1), \
+            MAX(c2) OVER (PARTITION BY c2 ORDER BY c1), \
+            MIN(c2) OVER (PARTITION BY c2 ORDER BY c1), \
+            AVG(c2) OVER (PARTITION BY c2 ORDER BY c1) \
+            FROM test \
+            ORDER BY c1, c2 \
+            LIMIT 5",
+            4,
+        )
+        .await?;
+
+        let expected = vec![
+            
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
+            "| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2 Plus c1) | 
LAST_VALUE(c2 Plus c1) | NTH_VALUE(c2 Plus c1,Int64(2)) | SUM(c2) | COUNT(c2) | 
MAX(c2) | MIN(c2) | AVG(c2) |",
+            
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
+            "| 0  | 1  | 1            | 1                       | 4            
          | 2                              | 1       | 1         | 1       | 1  
     | 1       |",
+            "| 0  | 2  | 1            | 2                       | 5            
          | 3                              | 2       | 1         | 2       | 2  
     | 2       |",
+            "| 0  | 3  | 1            | 3                       | 6            
          | 4                              | 3       | 1         | 3       | 3  
     | 3       |",
+            "| 0  | 4  | 1            | 4                       | 7            
          | 5                              | 4       | 1         | 4       | 4  
     | 4       |",
+            "| 0  | 5  | 1            | 5                       | 8            
          | 6                              | 5       | 1         | 5       | 5  
     | 5       |",
+            
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
+        ];
+
+        // window function shall respect ordering
+        assert_batches_eq!(expected, &results);
+        Ok(())
+    }
+
+    #[tokio::test]
     async fn aggregate() -> Result<()> {
         let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;
         assert_eq!(results.len(), 1);
diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs 
b/datafusion/src/physical_plan/expressions/nth_value.rs
index 98083fa..16897d4 100644
--- a/datafusion/src/physical_plan/expressions/nth_value.rs
+++ b/datafusion/src/physical_plan/expressions/nth_value.rs
@@ -20,7 +20,7 @@
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, 
PhysicalExpr};
 use crate::scalar::ScalarValue;
-use arrow::array::{new_empty_array, ArrayRef};
+use arrow::array::{new_empty_array, new_null_array, ArrayRef};
 use arrow::datatypes::{DataType, Field};
 use std::any::Any;
 use std::sync::Arc;
@@ -135,8 +135,12 @@ impl BuiltInWindowFunctionExpr for NthValue {
             NthValueKind::Last => (num_rows as usize) - 1,
             NthValueKind::Nth(n) => (n as usize) - 1,
         };
-        let value = ScalarValue::try_from_array(value, index)?;
-        Ok(value.to_array_of_size(num_rows))
+        Ok(if index >= num_rows {
+            new_null_array(value.data_type(), num_rows)
+        } else {
+            let value = ScalarValue::try_from_array(value, index)?;
+            value.to_array_of_size(num_rows)
+        })
     }
 }
 
diff --git a/datafusion/src/physical_plan/mod.rs 
b/datafusion/src/physical_plan/mod.rs
index 713956f..50c30a5 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -485,19 +485,20 @@ pub trait WindowExpr: Send + Sync + Debug {
     /// evaluate the window function values against the batch
     fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
 
-    /// evaluate the sort partition points
-    fn evaluate_sort_partition_points(
+    /// evaluate the partition points given the sort columns; if the sort 
columns are
+    /// empty then the result will be a single element vec of the whole column 
rows.
+    fn evaluate_partition_points(
         &self,
-        batch: &RecordBatch,
+        num_rows: usize,
+        partition_columns: &[SortColumn],
     ) -> Result<Vec<Range<usize>>> {
-        let sort_columns = self.sort_columns(batch)?;
-        if sort_columns.is_empty() {
+        if partition_columns.is_empty() {
             Ok(vec![Range {
                 start: 0,
-                end: batch.num_rows(),
+                end: num_rows,
             }])
         } else {
-            lexicographical_partition_ranges(&sort_columns)
+            lexicographical_partition_ranges(partition_columns)
                 .map_err(DataFusionError::ArrowError)
         }
     }
@@ -508,8 +509,8 @@ pub trait WindowExpr: Send + Sync + Debug {
     /// expressions that's from the window function's order by clause, empty 
if absent
     fn order_by(&self) -> &[PhysicalSortExpr];
 
-    /// get sort columns that can be used for partitioning, empty if absent
-    fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
+    /// get partition columns that can be used for partitioning, empty if 
absent
+    fn partition_columns(&self, batch: &RecordBatch) -> 
Result<Vec<SortColumn>> {
         self.partition_by()
             .iter()
             .map(|expr| {
@@ -519,13 +520,20 @@ pub trait WindowExpr: Send + Sync + Debug {
                 }
                 .evaluate_to_sort_column(batch)
             })
-            .chain(
-                self.order_by()
-                    .iter()
-                    .map(|e| e.evaluate_to_sort_column(batch)),
-            )
             .collect()
     }
+
+    /// get sort columns that can be used for peer evaluation, empty if absent
+    fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
+        let mut sort_columns = self.partition_columns(batch)?;
+        let order_by_columns = self
+            .order_by()
+            .iter()
+            .map(|e| e.evaluate_to_sort_column(batch))
+            .collect::<Result<Vec<SortColumn>>>()?;
+        sort_columns.extend(order_by_columns);
+        Ok(sort_columns)
+    }
 }
 
 /// An accumulator represents a stateful object that lives throughout the 
evaluation of multiple rows and
diff --git a/datafusion/src/physical_plan/planner.rs 
b/datafusion/src/physical_plan/planner.rs
index 1121c28..af0e60f 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -775,12 +775,6 @@ impl DefaultPhysicalPlanner {
                         )),
                     })
                     .collect::<Result<Vec<_>>>()?;
-                if !partition_by.is_empty() {
-                    return Err(DataFusionError::NotImplemented(
-                            "window expression with non-empty partition by 
clause is not yet supported"
-                                .to_owned(),
-                        ));
-                }
                 if window_frame.is_some() {
                     return Err(DataFusionError::NotImplemented(
                             "window expression with window frame definition is 
not yet supported"
diff --git a/datafusion/src/physical_plan/windows.rs 
b/datafusion/src/physical_plan/windows.rs
index e557097..466cc51 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -175,10 +175,45 @@ impl WindowExpr for BuiltInWindowExpr {
         // case when partition_by is supported, in which case we'll 
parallelize the calls.
         // See https://github.com/apache/arrow-datafusion/issues/299
         let values = self.evaluate_args(batch)?;
-        self.window.evaluate(batch.num_rows(), &values)
+        let partition_points = self.evaluate_partition_points(
+            batch.num_rows(),
+            &self.partition_columns(batch)?,
+        )?;
+        let results = partition_points
+            .iter()
+            .map(|partition_range| {
+                let start = partition_range.start;
+                let len = partition_range.end - start;
+                let values = values
+                    .iter()
+                    .map(|arr| arr.slice(start, len))
+                    .collect::<Vec<_>>();
+                self.window.evaluate(len, &values)
+            })
+            .collect::<Result<Vec<_>>>()?
+            .into_iter()
+            .collect::<Vec<ArrayRef>>();
+        let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
+        concat(&results).map_err(DataFusionError::ArrowError)
     }
 }
 
+/// Given a partition range and the full list of sort partition points, return the
+/// sub-slice of sort partition points that lie within that partition range. Because the
+/// sort partition points are computed over [partition columns..., order columns...],
+/// they are a refinement of the partition ranges and never straddle a partition
+/// boundary, so binary search on the range bounds is enough to locate the slice.
+fn find_ranges_in_range<'a>(
+    partition_range: &Range<usize>,
+    sort_partition_points: &'a [Range<usize>],
+) -> &'a [Range<usize>] {
+    let start_idx = sort_partition_points
+        .partition_point(|sort_range| sort_range.start < 
partition_range.start);
+    let end_idx = sort_partition_points
+        .partition_point(|sort_range| sort_range.end <= partition_range.end);
+    &sort_partition_points[start_idx..end_idx]
+}
+
 /// A window expr that takes the form of an aggregate function
 #[derive(Debug)]
 pub struct AggregateWindowExpr {
@@ -205,13 +240,27 @@ impl AggregateWindowExpr {
     /// and then per partition point we'll evaluate the peer group (e.g. SUM 
or MAX gives the same
     /// results for peers) and concatenate the results.
     fn peer_based_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
-        let sort_partition_points = 
self.evaluate_sort_partition_points(batch)?;
-        let mut window_accumulators = self.create_accumulator()?;
+        let num_rows = batch.num_rows();
+        let partition_points =
+            self.evaluate_partition_points(num_rows, 
&self.partition_columns(batch)?)?;
+        let sort_partition_points =
+            self.evaluate_partition_points(num_rows, 
&self.sort_columns(batch)?)?;
         let values = self.evaluate_args(batch)?;
-        let results = sort_partition_points
+        let results = partition_points
             .iter()
-            .map(|peer_range| window_accumulators.scan_peers(&values, 
peer_range))
-            .collect::<Result<Vec<_>>>()?;
+            .map(|partition_range| {
+                let sort_partition_points =
+                    find_ranges_in_range(partition_range, 
&sort_partition_points);
+                let mut window_accumulators = self.create_accumulator()?;
+                sort_partition_points
+                    .iter()
+                    .map(|range| window_accumulators.scan_peers(&values, 
range))
+                    .collect::<Result<Vec<_>>>()
+            })
+            .collect::<Result<Vec<Vec<ArrayRef>>>>()?
+            .into_iter()
+            .flatten()
+            .collect::<Vec<ArrayRef>>();
         let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
         concat(&results).map_err(DataFusionError::ArrowError)
     }
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index b6393e9..cfdb6f4 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -869,6 +869,70 @@ async fn csv_query_window_with_empty_over() -> Result<()> {
 }
 
 #[tokio::test]
+async fn csv_query_window_with_partition_by() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx)?;
+    let sql = "select \
+    c9, \
+    sum(cast(c4 as Int)) over (partition by c3), \
+    avg(cast(c4 as Int)) over (partition by c3), \
+    count(cast(c4 as Int)) over (partition by c3), \
+    max(cast(c4 as Int)) over (partition by c3), \
+    min(cast(c4 as Int)) over (partition by c3), \
+    first_value(cast(c4 as Int)) over (partition by c3), \
+    last_value(cast(c4 as Int)) over (partition by c3), \
+    nth_value(cast(c4 as Int), 2) over (partition by c3) \
+    from aggregate_test_100 \
+    order by c9 \
+    limit 5";
+    let actual = execute(&mut ctx, sql).await;
+    let expected = vec![
+        vec![
+            "28774375", "-16110", "-16110", "1", "-16110", "-16110", "-16110", 
"-16110",
+            "NULL",
+        ],
+        vec![
+            "63044568", "3917", "3917", "1", "3917", "3917", "3917", "3917", 
"NULL",
+        ],
+        vec![
+            "141047417",
+            "-38455",
+            "-19227.5",
+            "2",
+            "-16974",
+            "-21481",
+            "-16974",
+            "-21481",
+            "-21481",
+        ],
+        vec![
+            "141680161",
+            "-1114",
+            "-1114",
+            "1",
+            "-1114",
+            "-1114",
+            "-1114",
+            "-1114",
+            "NULL",
+        ],
+        vec![
+            "145294611",
+            "15673",
+            "15673",
+            "1",
+            "15673",
+            "15673",
+            "15673",
+            "15673",
+            "NULL",
+        ],
+    ];
+    assert_eq!(expected, actual);
+    Ok(())
+}
+
+#[tokio::test]
 async fn csv_query_window_with_order_by() -> Result<()> {
     let mut ctx = ExecutionContext::new();
     register_aggregate_csv(&mut ctx)?;
diff --git a/integration-tests/sqls/simple_window_partition_aggregation.sql 
b/integration-tests/sqls/simple_window_partition_aggregation.sql
new file mode 100644
index 0000000..f395671
--- /dev/null
+++ b/integration-tests/sqls/simple_window_partition_aggregation.sql
@@ -0,0 +1,26 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements.  See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership.  The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License.  You may obtain a copy of the License at
+
+-- http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+SELECT
+  c9,
+  row_number() OVER (PARTITION BY c2, c9) AS row_number,
+  count(c3) OVER (PARTITION BY c2) AS count_c3,
+  avg(c3) OVER (PARTITION BY c2) AS avg_c3_by_c2,
+  sum(c3) OVER (PARTITION BY c2) AS sum_c3_by_c2,
+  max(c3) OVER (PARTITION BY c2) AS max_c3_by_c2,
+  min(c3) OVER (PARTITION BY c2) AS min_c3_by_c2
+FROM test
+ORDER BY c9;
diff --git 
a/integration-tests/sqls/simple_window_partition_order_aggregation.sql 
b/integration-tests/sqls/simple_window_partition_order_aggregation.sql
new file mode 100644
index 0000000..a11a9ec
--- /dev/null
+++ b/integration-tests/sqls/simple_window_partition_order_aggregation.sql
@@ -0,0 +1,26 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements.  See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership.  The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License.  You may obtain a copy of the License at
+
+-- http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+SELECT
+  c9,
+  row_number() OVER (PARTITION BY c2 ORDER BY c9) AS row_number,
+  count(c3) OVER (PARTITION BY c2 ORDER BY c9) AS count_c3,
+  avg(c3) OVER (PARTITION BY c2 ORDER BY c9) AS avg_c3_by_c2,
+  sum(c3) OVER (PARTITION BY c2 ORDER BY c9) AS sum_c3_by_c2,
+  max(c3) OVER (PARTITION BY c2 ORDER BY c9) AS max_c3_by_c2,
+  min(c3) OVER (PARTITION BY c2 ORDER BY c9) AS min_c3_by_c2
+FROM test
+ORDER BY c9;
diff --git a/integration-tests/test_psql_parity.py 
b/integration-tests/test_psql_parity.py
index 4e0878c..c4b5a75 100644
--- a/integration-tests/test_psql_parity.py
+++ b/integration-tests/test_psql_parity.py
@@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase):
     def test_parity(self):
         root = Path(os.path.dirname(__file__)) / "sqls"
         files = set(root.glob("*.sql"))
-        self.assertEqual(len(files), 7, msg="tests are missed")
+        self.assertEqual(len(files), 9, msg="tests are missed")
         for fname in files:
             with self.subTest(fname=fname):
                 datafusion_output = pd.read_csv(

Reply via email to