This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 8e12e48 add rank and dense rank and refactor window built in functions (#631)
8e12e48 is described below
commit 8e12e482830afcd619fadca237f2c6412883a63d
Author: Jiayu Liu <[email protected]>
AuthorDate: Tue Jun 29 05:54:12 2021 +0800
add rank and dense rank and refactor window built in functions (#631)
---
 datafusion/src/physical_plan/expressions/mod.rs  |   2 +
 .../src/physical_plan/expressions/nth_value.rs   |  59 ++++---
 datafusion/src/physical_plan/expressions/rank.rs | 172 +++++++++++++++++++++
 .../src/physical_plan/expressions/row_number.rs  |  32 +++-
 datafusion/src/physical_plan/window_functions.rs |  63 +++++++-
 datafusion/src/physical_plan/windows.rs          |  55 ++++---
 6 files changed, 318 insertions(+), 65 deletions(-)
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index 0b32dca..440cb5b 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -42,6 +42,7 @@ mod negative;
 mod not;
 mod nth_value;
 mod nullif;
+mod rank;
 mod row_number;
 mod sum;
 mod try_cast;
@@ -63,6 +64,7 @@ pub use negative::{negative, NegativeExpr};
 pub use not::{not, NotExpr};
 pub use nth_value::NthValue;
 pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
+pub use rank::{dense_rank, rank};
 pub use row_number::RowNumber;
 pub use sum::{sum_return_type, Sum};
 pub use try_cast::{try_cast, TryCastExpr};
diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs
index b548f91..3897ae5 100644
--- a/datafusion/src/physical_plan/expressions/nth_value.rs
+++ b/datafusion/src/physical_plan/expressions/nth_value.rs
@@ -18,11 +18,14 @@
 //! Defines physical expressions that can evaluated at runtime during query execution
 
 use crate::error::{DataFusionError, Result};
+use crate::physical_plan::window_functions::PartitionEvaluator;
 use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
 use crate::scalar::ScalarValue;
-use arrow::array::{new_empty_array, new_null_array, ArrayRef};
+use arrow::array::{new_null_array, ArrayRef};
 use arrow::datatypes::{DataType, Field};
+use arrow::record_batch::RecordBatch;
 use std::any::Any;
+use std::ops::Range;
 use std::sync::Arc;
 
 /// nth_value kind
@@ -111,25 +114,34 @@ impl BuiltInWindowFunctionExpr for NthValue {
         &self.name
     }
 
-    fn evaluate(&self, num_rows: usize, values: &[ArrayRef]) -> Result<ArrayRef> {
-        if values.is_empty() {
-            return Err(DataFusionError::Execution(format!(
-                "No arguments supplied to {}",
-                self.name()
-            )));
-        }
-        let value = &values[0];
-        if value.len() != num_rows {
-            return Err(DataFusionError::Execution(format!(
-                "Invalid data supplied to {}, expect {} rows, got {} rows",
-                self.name(),
-                num_rows,
-                value.len()
-            )));
-        }
-        if num_rows == 0 {
-            return Ok(new_empty_array(value.data_type()));
-        }
+    fn create_evaluator(
+        &self,
+        batch: &RecordBatch,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
+        let values = self
+            .expressions()
+            .iter()
+            .map(|e| e.evaluate(batch))
+            .map(|r| r.map(|v| v.into_array(batch.num_rows())))
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Box::new(NthValueEvaluator {
+            kind: self.kind,
+            values,
+        }))
+    }
+}
+
+/// Value evaluator for nth_value functions
+pub(crate) struct NthValueEvaluator {
+    kind: NthValueKind,
+    values: Vec<ArrayRef>,
+}
+
+impl PartitionEvaluator for NthValueEvaluator {
+    fn evaluate_partition(&self, partition: Range<usize>) -> Result<ArrayRef> {
+        let value = &self.values[0];
+        let num_rows = partition.end - partition.start;
+        let value = value.slice(partition.start, num_rows);
         let index: usize = match self.kind {
             NthValueKind::First => 0,
             NthValueKind::Last => (num_rows as usize) - 1,
@@ -138,7 +150,7 @@ impl BuiltInWindowFunctionExpr for NthValue {
         Ok(if index >= num_rows {
             new_null_array(value.data_type(), num_rows)
         } else {
-            let value = ScalarValue::try_from_array(value, index)?;
+            let value = ScalarValue::try_from_array(&value, index)?;
             value.to_array_of_size(num_rows)
         })
     }
@@ -157,8 +169,9 @@ mod tests {
         let values = vec![arr];
         let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
         let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
-        let result = expr.evaluate(batch.num_rows(), &values)?;
-        let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
+        let result = expr.create_evaluator(&batch)?.evaluate(vec![0..8])?;
+        assert_eq!(1, result.len());
+        let result = result[0].as_any().downcast_ref::<Int32Array>().unwrap();
         let result = result.values();
         assert_eq!(expected, result);
         Ok(())
diff --git a/datafusion/src/physical_plan/expressions/rank.rs b/datafusion/src/physical_plan/expressions/rank.rs
new file mode 100644
index 0000000..b88dec3
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/rank.rs
@@ -0,0 +1,172 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can evaluated at runtime during query execution
+
+use crate::error::Result;
+use crate::physical_plan::window_functions::PartitionEvaluator;
+use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
+use arrow::array::ArrayRef;
+use arrow::array::UInt64Array;
+use arrow::datatypes::{DataType, Field};
+use arrow::record_batch::RecordBatch;
+use std::any::Any;
+use std::iter;
+use std::ops::Range;
+use std::sync::Arc;
+
+/// Rank calculates the rank in the window function with order by
+#[derive(Debug)]
+pub struct Rank {
+    name: String,
+    dense: bool,
+}
+
+/// Create a rank window function
+pub fn rank(name: String) -> Rank {
+    Rank { name, dense: false }
+}
+
+/// Create a dense rank window function
+pub fn dense_rank(name: String) -> Rank {
+    Rank { name, dense: true }
+}
+
+impl BuiltInWindowFunctionExpr for Rank {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        let nullable = false;
+        let data_type = DataType::UInt64;
+        Ok(Field::new(self.name(), data_type, nullable))
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn create_evaluator(
+        &self,
+        _batch: &RecordBatch,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
+        Ok(Box::new(RankEvaluator { dense: self.dense }))
+    }
+}
+
+pub(crate) struct RankEvaluator {
+    dense: bool,
+}
+
+impl PartitionEvaluator for RankEvaluator {
+    fn include_rank(&self) -> bool {
+        true
+    }
+
+    fn evaluate_partition(&self, _partition: Range<usize>) -> Result<ArrayRef> {
+        unreachable!("rank evaluation must be called with evaluate_partition_with_rank")
+    }
+
+    fn evaluate_partition_with_rank(
+        &self,
+        _partition: Range<usize>,
+        ranks_in_partition: &[Range<usize>],
+    ) -> Result<ArrayRef> {
+        let result = if self.dense {
+            UInt64Array::from_iter_values(ranks_in_partition.iter().zip(1u64..).flat_map(
+                |(range, rank)| {
+                    let len = range.end - range.start;
+                    iter::repeat(rank).take(len)
+                },
+            ))
+        } else {
+            UInt64Array::from_iter_values(
+                ranks_in_partition
+                    .iter()
+                    .scan(1_u64, |acc, range| {
+                        let len = range.end - range.start;
+                        let result = iter::repeat(*acc).take(len);
+                        *acc += len as u64;
+                        Some(result)
+                    })
+                    .flatten(),
+            )
+        };
+        Ok(Arc::new(result))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::{array::*, datatypes::*};
+
+    fn test_with_rank(expr: &Rank, expected: Vec<u64>) -> Result<()> {
+        test_i32_result(
+            expr,
+            vec![-2, -2, 1, 3, 3, 3, 7, 8],
+            vec![0..2, 2..3, 3..6, 6..7, 7..8],
+            expected,
+        )
+    }
+
+    fn test_without_rank(expr: &Rank, expected: Vec<u64>) -> Result<()> {
+        test_i32_result(expr, vec![-2, -2, 1, 3, 3, 3, 7, 8], vec![0..8], expected)
+    }
+
+    fn test_i32_result(
+        expr: &Rank,
+        data: Vec<i32>,
+        ranks: Vec<Range<usize>>,
+        expected: Vec<u64>,
+    ) -> Result<()> {
+        let arr: ArrayRef = Arc::new(Int32Array::from(data));
+        let values = vec![arr];
+        let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
+        let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
+        let result = expr
+            .create_evaluator(&batch)?
+            .evaluate_with_rank(vec![0..8], ranks)?;
+        assert_eq!(1, result.len());
+        let result = result[0].as_any().downcast_ref::<UInt64Array>().unwrap();
+        let result = result.values();
+        assert_eq!(expected, result);
+        Ok(())
+    }
+
+    #[test]
+    fn test_dense_rank() -> Result<()> {
+        let r = dense_rank("arr".into());
+        test_without_rank(&r, vec![1; 8])?;
+        test_with_rank(&r, vec![1, 1, 2, 3, 3, 3, 4, 5])?;
+        Ok(())
+    }
+
+    #[test]
+    fn test_rank() -> Result<()> {
+        let r = rank("arr".into());
+        test_without_rank(&r, vec![1; 8])?;
+        test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?;
+        Ok(())
+    }
+}
diff --git a/datafusion/src/physical_plan/expressions/row_number.rs b/datafusion/src/physical_plan/expressions/row_number.rs
index 6b488cc..c65945f 100644
--- a/datafusion/src/physical_plan/expressions/row_number.rs
+++ b/datafusion/src/physical_plan/expressions/row_number.rs
@@ -18,10 +18,13 @@
 //! Defines physical expression for `row_number` that can evaluated at runtime during query execution
 
 use crate::error::Result;
+use crate::physical_plan::window_functions::PartitionEvaluator;
 use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
 use arrow::array::{ArrayRef, UInt64Array};
 use arrow::datatypes::{DataType, Field};
+use arrow::record_batch::RecordBatch;
 use std::any::Any;
+use std::ops::Range;
 use std::sync::Arc;
 
 /// row_number expression
@@ -54,12 +57,25 @@ impl BuiltInWindowFunctionExpr for RowNumber {
     }
 
     fn name(&self) -> &str {
-        self.name.as_str()
+        &self.name
     }
 
-    fn evaluate(&self, num_rows: usize, _values: &[ArrayRef]) -> Result<ArrayRef> {
+    fn create_evaluator(
+        &self,
+        _batch: &RecordBatch,
+    ) -> Result<Box<dyn PartitionEvaluator>> {
+        Ok(Box::new(NumRowsEvaluator::default()))
+    }
+}
+
+#[derive(Default)]
+pub(crate) struct NumRowsEvaluator {}
+
+impl PartitionEvaluator for NumRowsEvaluator {
+    fn evaluate_partition(&self, partition: Range<usize>) -> Result<ArrayRef> {
+        let num_rows = partition.end - partition.start;
         Ok(Arc::new(UInt64Array::from_iter_values(
-            (1..num_rows + 1).map(|i| i as u64),
+            1..(num_rows as u64) + 1,
         )))
     }
 }
@@ -79,8 +95,9 @@ mod tests {
         let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
         let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
         let row_number = RowNumber::new("row_number".to_owned());
-        let result = row_number.evaluate(batch.num_rows(), &[])?;
-        let result = result.as_any().downcast_ref::<UInt64Array>().unwrap();
+        let result = row_number.create_evaluator(&batch)?.evaluate(vec![0..8])?;
+        assert_eq!(1, result.len());
+        let result = result[0].as_any().downcast_ref::<UInt64Array>().unwrap();
         let result = result.values();
         assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);
         Ok(())
@@ -94,8 +111,9 @@ mod tests {
         let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
         let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
         let row_number = RowNumber::new("row_number".to_owned());
-        let result = row_number.evaluate(batch.num_rows(), &[])?;
-        let result = result.as_any().downcast_ref::<UInt64Array>().unwrap();
+        let result = row_number.create_evaluator(&batch)?.evaluate(vec![0..8])?;
+        assert_eq!(1, result.len());
+        let result = result[0].as_any().downcast_ref::<UInt64Array>().unwrap();
         let result = result.values();
         assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);
         Ok(())
diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs
index 4f56aa7..99805b6 100644
--- a/datafusion/src/physical_plan/window_functions.rs
+++ b/datafusion/src/physical_plan/window_functions.rs
@@ -20,15 +20,17 @@
 //!
 //! see also https://www.postgresql.org/docs/current/functions-window.html
 
-use crate::arrow::array::ArrayRef;
-use crate::arrow::datatypes::Field;
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{
     aggregates, aggregates::AggregateFunction, functions::Signature,
-    type_coercion::data_types, PhysicalExpr,
+    type_coercion::data_types, windows::find_ranges_in_range, PhysicalExpr,
 };
+use arrow::array::ArrayRef;
 use arrow::datatypes::DataType;
+use arrow::datatypes::Field;
+use arrow::record_batch::RecordBatch;
 use std::any::Any;
+use std::ops::Range;
 use std::sync::Arc;
 use std::{fmt, str::FromStr};
 
@@ -208,11 +210,57 @@ pub(super) fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
     }
 }
 
+/// Partition evaluator
+pub(crate) trait PartitionEvaluator {
+    /// Whether the evaluator should be evaluated with rank
+    fn include_rank(&self) -> bool {
+        false
+    }
+
+    /// evaluate the partition evaluator against the partitions
+    fn evaluate(&self, partition_points: Vec<Range<usize>>) -> Result<Vec<ArrayRef>> {
+        partition_points
+            .into_iter()
+            .map(|partition| self.evaluate_partition(partition))
+            .collect()
+    }
+
+    /// evaluate the partition evaluator against the partitions with rank information
+    fn evaluate_with_rank(
+        &self,
+        partition_points: Vec<Range<usize>>,
+        sort_partition_points: Vec<Range<usize>>,
+    ) -> Result<Vec<ArrayRef>> {
+        partition_points
+            .into_iter()
+            .map(|partition| {
+                let ranks_in_partition =
+                    find_ranges_in_range(&partition, &sort_partition_points);
+                self.evaluate_partition_with_rank(partition, ranks_in_partition)
+            })
+            .collect()
+    }
+
+    /// evaluate the partition evaluator against the partition
+    fn evaluate_partition(&self, _partition: Range<usize>) -> Result<ArrayRef>;
+
+    /// evaluate the partition evaluator against the partition but with rank
+    fn evaluate_partition_with_rank(
+        &self,
+        _partition: Range<usize>,
+        _ranks_in_partition: &[Range<usize>],
+    ) -> Result<ArrayRef> {
+        Err(DataFusionError::NotImplemented(
+            "evaluate_partition_with_rank is not implemented by default".into(),
+        ))
+    }
+}
+
 /// A window expression that is a built-in window function.
 ///
 /// Note that unlike aggregation based window functions, built-in window functions normally ignore
 /// window frame spec, with the exception of first_value, last_value, and nth_value.
-pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug {
+pub(crate) trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug {
     /// Returns the aggregate expression as [`Any`](std::any::Any) so that it can be
     /// downcast to a specific implementation.
     fn as_any(&self) -> &dyn Any;
@@ -230,8 +278,11 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug {
         "BuiltInWindowFunctionExpr: default name"
     }
 
-    /// Evaluate the built-in window function against the number of rows and the arguments
-    fn evaluate(&self, num_rows: usize, values: &[ArrayRef]) -> Result<ArrayRef>;
+    /// Create built-in window evaluator with a batch
+    fn create_evaluator(
+        &self,
+        batch: &RecordBatch,
+    ) -> Result<Box<dyn PartitionEvaluator>>;
 }
 
 #[cfg(test)]
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index 2f53905..8926376 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -21,11 +21,12 @@ use crate::error::{DataFusionError, Result};
 use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits};
 use crate::physical_plan::{
     aggregates, common,
-    expressions::{Literal, NthValue, PhysicalSortExpr, RowNumber},
+    expressions::{dense_rank, rank, Literal, NthValue, PhysicalSortExpr, RowNumber},
     type_coercion::coerce,
-    window_functions::signature_for_built_in,
-    window_functions::BuiltInWindowFunctionExpr,
-    window_functions::{BuiltInWindowFunction, WindowFunction},
+    window_functions::{
+        signature_for_built_in, BuiltInWindowFunction, BuiltInWindowFunctionExpr,
+        WindowFunction,
+    },
     Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
     RecordBatchStream, SendableRecordBatchStream, WindowExpr,
 };
@@ -84,7 +85,8 @@ pub fn create_window_expr(
             window_frame,
         }),
         WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr {
-            window: create_built_in_window_expr(fun, args, input_schema, name)?,
+            fun: fun.clone(),
+            expr: create_built_in_window_expr(fun, args, input_schema, name)?,
             partition_by: partition_by.to_vec(),
             order_by: order_by.to_vec(),
             window_frame,
@@ -100,6 +102,8 @@ fn create_built_in_window_expr(
 ) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
     match fun {
         BuiltInWindowFunction::RowNumber => Ok(Arc::new(RowNumber::new(name))),
+        BuiltInWindowFunction::Rank => Ok(Arc::new(rank(name))),
+        BuiltInWindowFunction::DenseRank => Ok(Arc::new(dense_rank(name))),
         BuiltInWindowFunction::NthValue => {
             let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
             let arg = coerced_args[0].clone();
@@ -138,7 +142,8 @@ fn create_built_in_window_expr(
 /// A window expr that takes the form of a built in window function
 #[derive(Debug)]
 pub struct BuiltInWindowExpr {
-    window: Arc<dyn BuiltInWindowFunctionExpr>,
+    fun: BuiltInWindowFunction,
+    expr: Arc<dyn BuiltInWindowFunctionExpr>,
     partition_by: Vec<Arc<dyn PhysicalExpr>>,
     order_by: Vec<PhysicalSortExpr>,
     window_frame: Option<WindowFrame>,
@@ -151,15 +156,15 @@ impl WindowExpr for BuiltInWindowExpr {
     }
 
     fn name(&self) -> &str {
-        self.window.name()
+        self.expr.name()
     }
 
     fn field(&self) -> Result<Field> {
-        self.window.field()
+        self.expr.field()
     }
 
     fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        self.window.expressions()
+        self.expr.expressions()
     }
 
     fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
@@ -171,25 +176,17 @@ impl WindowExpr for BuiltInWindowExpr {
     }
 
     fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
-        let values = self.evaluate_args(batch)?;
-        let partition_points = self.evaluate_partition_points(
-            batch.num_rows(),
-            &self.partition_columns(batch)?,
-        )?;
-        let results = partition_points
-            .iter()
-            .map(|partition_range| {
-                let start = partition_range.start;
-                let len = partition_range.end - start;
-                let values = values
-                    .iter()
-                    .map(|arr| arr.slice(start, len))
-                    .collect::<Vec<_>>();
-                self.window.evaluate(len, &values)
-            })
-            .collect::<Result<Vec<_>>>()?
-            .into_iter()
-            .collect::<Vec<ArrayRef>>();
+        let evaluator = self.expr.create_evaluator(batch)?;
+        let num_rows = batch.num_rows();
+        let partition_points =
+            self.evaluate_partition_points(num_rows, &self.partition_columns(batch)?)?;
+        let results = if evaluator.include_rank() {
+            let sort_partition_points =
+                self.evaluate_partition_points(num_rows, &self.sort_columns(batch)?)?;
+            evaluator.evaluate_with_rank(partition_points, sort_partition_points)?
+        } else {
+            evaluator.evaluate(partition_points)?
+        };
         let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
         concat(&results).map_err(DataFusionError::ArrowError)
     }
@@ -200,7 +197,7 @@ impl WindowExpr for BuiltInWindowExpr {
 /// boundaries would align (what's sorted on [partition columns...] would definitely be sorted
 /// on finer columns), so this will use binary search to find ranges that are within the
 /// partition range and return the valid slice.
-fn find_ranges_in_range<'a>(
+pub(crate) fn find_ranges_in_range<'a>(
     partition_range: &Range<usize>,
     sort_partition_points: &'a [Range<usize>],
 ) -> &'a [Range<usize>] {