This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d2b3d1c753 Rename `expr::window_function::WindowFunction` to
`WindowFunctionDefinition`, make structure consistent with ScalarFunction
(#8382)
d2b3d1c753 is described below
commit d2b3d1c7538b9fb7ab9cfc0c4c6a238b0dcd91e6
Author: Edmondo Porcu <[email protected]>
AuthorDate: Mon Jan 1 14:09:41 2024 -0500
Rename `expr::window_function::WindowFunction` to
`WindowFunctionDefinition`, make structure consistent with ScalarFunction
(#8382)
* Refactor WindowFunction into a structure consistent with AggregateFunction
* One more cargo fmt
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/src/dataframe/mod.rs | 6 +-
.../core/src/physical_optimizer/test_utils.rs | 4 +-
datafusion/core/tests/dataframe/mod.rs | 4 +-
datafusion/core/tests/fuzz_cases/window_fuzz.rs | 46 +-
datafusion/expr/src/built_in_window_function.rs | 207 +++++++++
datafusion/expr/src/expr.rs | 291 ++++++++++++-
datafusion/expr/src/lib.rs | 6 +-
datafusion/expr/src/udwf.rs | 2 +-
datafusion/expr/src/utils.rs | 22 +-
datafusion/expr/src/window_function.rs | 483 ---------------------
.../optimizer/src/analyzer/count_wildcard_rule.rs | 10 +-
datafusion/optimizer/src/analyzer/type_coercion.rs | 8 +-
datafusion/optimizer/src/push_down_projection.rs | 6 +-
datafusion/physical-plan/src/windows/mod.rs | 28 +-
datafusion/proto/src/logical_plan/from_proto.rs | 8 +-
datafusion/proto/src/logical_plan/to_proto.rs | 10 +-
datafusion/proto/src/physical_plan/from_proto.rs | 10 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 20 +-
datafusion/sql/src/expr/function.rs | 19 +-
datafusion/substrait/src/logical_plan/consumer.rs | 4 +-
20 files changed, 613 insertions(+), 581 deletions(-)
diff --git a/datafusion/core/src/dataframe/mod.rs
b/datafusion/core/src/dataframe/mod.rs
index 3c3bcd497b..5a8c706e32 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -1360,7 +1360,7 @@ mod tests {
use datafusion_expr::{
avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum,
BuiltInWindowFunction, ScalarFunctionImplementation, Volatility,
WindowFrame,
- WindowFunction,
+ WindowFunctionDefinition,
};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::get_plan_string;
@@ -1525,7 +1525,9 @@ mod tests {
// build plan using Table API
let t = test_table().await?;
let first_row = Expr::WindowFunction(expr::WindowFunction::new(
-
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
+ WindowFunctionDefinition::BuiltInWindowFunction(
+ BuiltInWindowFunction::FirstValue,
+ ),
vec![col("aggregate_test_100.c1")],
vec![col("aggregate_test_100.c2")],
vec![],
diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs
b/datafusion/core/src/physical_optimizer/test_utils.rs
index 6e14cca21f..debafefe39 100644
--- a/datafusion/core/src/physical_optimizer/test_utils.rs
+++ b/datafusion/core/src/physical_optimizer/test_utils.rs
@@ -41,7 +41,7 @@ use crate::prelude::{CsvReadOptions, SessionContext};
use arrow_schema::{Schema, SchemaRef, SortOptions};
use datafusion_common::{JoinType, Statistics};
use datafusion_execution::object_store::ObjectStoreUrl;
-use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction};
+use datafusion_expr::{AggregateFunction, WindowFrame,
WindowFunctionDefinition};
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
@@ -234,7 +234,7 @@ pub fn bounded_window_exec(
Arc::new(
crate::physical_plan::windows::BoundedWindowAggExec::try_new(
vec![create_window_expr(
- &WindowFunction::AggregateFunction(AggregateFunction::Count),
+
&WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
"count".to_owned(),
&[col(col_name, &schema).unwrap()],
&[],
diff --git a/datafusion/core/tests/dataframe/mod.rs
b/datafusion/core/tests/dataframe/mod.rs
index ba661aa244..cca23ac684 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -45,7 +45,7 @@ use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::{
array_agg, avg, col, count, exists, expr, in_subquery, lit, max,
out_ref_col,
scalar_subquery, sum, wildcard, AggregateFunction, Expr, ExprSchemable,
WindowFrame,
- WindowFrameBound, WindowFrameUnits, WindowFunction,
+ WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_physical_expr::var_provider::{VarProvider, VarType};
@@ -170,7 +170,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
.table("t1")
.await?
.select(vec![Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Count),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 44ff71d023..3037b4857a 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -33,7 +33,7 @@ use datafusion_common::{Result, ScalarValue};
use datafusion_expr::type_coercion::aggregates::coerce_types;
use datafusion_expr::{
AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound,
- WindowFrameUnits, WindowFunction,
+ WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
@@ -143,7 +143,7 @@ fn get_random_function(
schema: &SchemaRef,
rng: &mut StdRng,
is_linear: bool,
-) -> (WindowFunction, Vec<Arc<dyn PhysicalExpr>>, String) {
+) -> (WindowFunctionDefinition, Vec<Arc<dyn PhysicalExpr>>, String) {
let mut args = if is_linear {
// In linear test for the test version with WindowAggExec we use
insert SortExecs to the plan to be able to generate
// same result with BoundedWindowAggExec which doesn't use any
SortExec. To make result
@@ -159,28 +159,28 @@ fn get_random_function(
window_fn_map.insert(
"sum",
(
- WindowFunction::AggregateFunction(AggregateFunction::Sum),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![],
),
);
window_fn_map.insert(
"count",
(
- WindowFunction::AggregateFunction(AggregateFunction::Count),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
vec![],
),
);
window_fn_map.insert(
"min",
(
- WindowFunction::AggregateFunction(AggregateFunction::Min),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![],
),
);
window_fn_map.insert(
"max",
(
- WindowFunction::AggregateFunction(AggregateFunction::Max),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![],
),
);
@@ -191,28 +191,36 @@ fn get_random_function(
window_fn_map.insert(
"row_number",
(
-
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber),
+ WindowFunctionDefinition::BuiltInWindowFunction(
+ BuiltInWindowFunction::RowNumber,
+ ),
vec![],
),
);
window_fn_map.insert(
"rank",
(
-
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Rank),
+ WindowFunctionDefinition::BuiltInWindowFunction(
+ BuiltInWindowFunction::Rank,
+ ),
vec![],
),
);
window_fn_map.insert(
"dense_rank",
(
-
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::DenseRank),
+ WindowFunctionDefinition::BuiltInWindowFunction(
+ BuiltInWindowFunction::DenseRank,
+ ),
vec![],
),
);
window_fn_map.insert(
"lead",
(
-
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
+ WindowFunctionDefinition::BuiltInWindowFunction(
+ BuiltInWindowFunction::Lead,
+ ),
vec![
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))),
@@ -222,7 +230,9 @@ fn get_random_function(
window_fn_map.insert(
"lag",
(
-
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag),
+ WindowFunctionDefinition::BuiltInWindowFunction(
+ BuiltInWindowFunction::Lag,
+ ),
vec![
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))),
@@ -233,21 +243,27 @@ fn get_random_function(
window_fn_map.insert(
"first_value",
(
-
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
+ WindowFunctionDefinition::BuiltInWindowFunction(
+ BuiltInWindowFunction::FirstValue,
+ ),
vec![],
),
);
window_fn_map.insert(
"last_value",
(
-
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue),
+ WindowFunctionDefinition::BuiltInWindowFunction(
+ BuiltInWindowFunction::LastValue,
+ ),
vec![],
),
);
window_fn_map.insert(
"nth_value",
(
-
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::NthValue),
+ WindowFunctionDefinition::BuiltInWindowFunction(
+ BuiltInWindowFunction::NthValue,
+ ),
vec![lit(ScalarValue::Int64(Some(rng.gen_range(1..10))))],
),
);
@@ -255,7 +271,7 @@ fn get_random_function(
let rand_fn_idx = rng.gen_range(0..window_fn_map.len());
let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
let (window_fn, new_args) =
window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
- if let WindowFunction::AggregateFunction(f) = window_fn {
+ if let WindowFunctionDefinition::AggregateFunction(f) = window_fn {
let a = args[0].clone();
let dt = a.data_type(schema.as_ref()).unwrap();
let sig = f.signature();
diff --git a/datafusion/expr/src/built_in_window_function.rs
b/datafusion/expr/src/built_in_window_function.rs
new file mode 100644
index 0000000000..a03e3d2d24
--- /dev/null
+++ b/datafusion/expr/src/built_in_window_function.rs
@@ -0,0 +1,207 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Built-in window functions module: defines [`BuiltInWindowFunction`] together with its return-type and signature rules.
+
+use std::fmt;
+use std::str::FromStr;
+
+use crate::type_coercion::functions::data_types;
+use crate::utils;
+use crate::{Signature, TypeSignature, Volatility};
+use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError,
Result};
+
+use arrow::datatypes::DataType;
+
+use strum_macros::EnumIter;
+
+/// Formats the function as its upper-case SQL name (e.g. `ROW_NUMBER`),
+/// which is also a spelling accepted by this type's `FromStr` impl.
+impl fmt::Display for BuiltInWindowFunction {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ // Delegates to `name()`, which maps each variant to its SQL keyword.
+ write!(f, "{}", self.name())
+ }
+}
+
+/// A [window function] built in to DataFusion
+///
+/// Each variant's SQL name (as produced by `Display` and parsed by `FromStr`)
+/// is its upper-case snake_case form, e.g. `ROW_NUMBER`. `EnumIter` is derived
+/// so that every variant can be iterated over (used by the round-trip test).
+///
+/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
+#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)]
+pub enum BuiltInWindowFunction {
+ /// number of the current row within its partition, counting from 1
+ RowNumber,
+ /// rank of the current row with gaps; same as row_number of its first peer
+ Rank,
+ /// rank of the current row without gaps; this function counts peer groups
+ DenseRank,
+ /// relative rank of the current row: (rank - 1) / (total rows - 1)
+ PercentRank,
+ /// relative rank of the current row:
+ /// (number of rows preceding or peer with current row) / (total rows)
+ CumeDist,
+ /// integer ranging from 1 to the argument value, dividing the partition
+ /// as equally as possible
+ Ntile,
+ /// returns value evaluated at the row that is offset rows before the
+ /// current row within the partition; if there is no such row, instead
+ /// return default (which must be of the same type as value).
+ /// Both offset and default are evaluated with respect to the current row.
+ /// If omitted, offset defaults to 1 and default to null
+ Lag,
+ /// returns value evaluated at the row that is offset rows after the
+ /// current row within the partition; if there is no such row, instead
+ /// return default (which must be of the same type as value).
+ /// Both offset and default are evaluated with respect to the current row.
+ /// If omitted, offset defaults to 1 and default to null
+ Lead,
+ /// returns value evaluated at the row that is the first row of the window
+ /// frame
+ FirstValue,
+ /// returns value evaluated at the row that is the last row of the window
+ /// frame
+ LastValue,
+ /// returns value evaluated at the row that is the nth row of the window
+ /// frame (counting from 1); null if no such row
+ NthValue,
+}
+
+impl BuiltInWindowFunction {
+ /// Returns the upper-case SQL name of this function (e.g. `"ROW_NUMBER"`).
+ ///
+ /// Used by the `Display` impl. Every name returned here must also be
+ /// accepted by `FromStr` (checked by `test_display_and_from_str`).
+ fn name(&self) -> &str {
+ use BuiltInWindowFunction::*;
+ match self {
+ RowNumber => "ROW_NUMBER",
+ Rank => "RANK",
+ DenseRank => "DENSE_RANK",
+ PercentRank => "PERCENT_RANK",
+ CumeDist => "CUME_DIST",
+ Ntile => "NTILE",
+ Lag => "LAG",
+ Lead => "LEAD",
+ FirstValue => "FIRST_VALUE",
+ LastValue => "LAST_VALUE",
+ NthValue => "NTH_VALUE",
+ }
+ }
+}
+
+/// Parses a built-in window function from its SQL name.
+///
+/// Matching is case-insensitive: the input is upper-cased before it is
+/// compared against the names produced by `Display`.
+impl FromStr for BuiltInWindowFunction {
+ type Err = DataFusionError;
+ /// # Errors
+ ///
+ /// Returns a planning error (via `plan_err!`) when `name` is not a known
+ /// built-in window function.
+ fn from_str(name: &str) -> Result<BuiltInWindowFunction> {
+ Ok(match name.to_uppercase().as_str() {
+ "ROW_NUMBER" => BuiltInWindowFunction::RowNumber,
+ "RANK" => BuiltInWindowFunction::Rank,
+ "DENSE_RANK" => BuiltInWindowFunction::DenseRank,
+ "PERCENT_RANK" => BuiltInWindowFunction::PercentRank,
+ "CUME_DIST" => BuiltInWindowFunction::CumeDist,
+ "NTILE" => BuiltInWindowFunction::Ntile,
+ "LAG" => BuiltInWindowFunction::Lag,
+ "LEAD" => BuiltInWindowFunction::Lead,
+ "FIRST_VALUE" => BuiltInWindowFunction::FirstValue,
+ "LAST_VALUE" => BuiltInWindowFunction::LastValue,
+ "NTH_VALUE" => BuiltInWindowFunction::NthValue,
+ _ => return plan_err!("There is no built-in window function named
{name}"),
+ })
+ }
+}
+
+/// Type information for built-in window functions: return types and signatures.
+impl BuiltInWindowFunction {
+ /// Returns the datatype of the built-in window function, given the types of
+ /// its input expressions.
+ ///
+ /// # Errors
+ ///
+ /// Returns a planning error if `input_expr_types` does not satisfy
+ /// [`Self::signature`].
+ pub fn return_type(&self, input_expr_types: &[DataType]) ->
Result<DataType> {
+ // Note that this function *must* return the same type that the respective
+ // physical expression returns, or the execution panics.
+
+ // verify that this is a valid set of data types for this function
+ data_types(input_expr_types, &self.signature())
+ // original errors are all related to wrong function signature
+ // aggregate them for better error message
+ .map_err(|_| {
+ plan_datafusion_err!(
+ "{}",
+ utils::generate_signature_error_msg(
+ &format!("{self}"),
+ self.signature(),
+ input_expr_types,
+ )
+ )
+ })?;
+
+ match self {
+ BuiltInWindowFunction::RowNumber
+ | BuiltInWindowFunction::Rank
+ | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64),
+ BuiltInWindowFunction::PercentRank |
BuiltInWindowFunction::CumeDist => {
+ Ok(DataType::Float64)
+ }
+ BuiltInWindowFunction::Ntile => Ok(DataType::UInt64),
+ // These return the type of their first (value) argument. The indexing
+ // relies on `signature()` requiring at least one argument for each of
+ // these variants; an empty slice is expected to have been rejected by
+ // the `data_types` check above.
+ BuiltInWindowFunction::Lag
+ | BuiltInWindowFunction::Lead
+ | BuiltInWindowFunction::FirstValue
+ | BuiltInWindowFunction::LastValue
+ | BuiltInWindowFunction::NthValue =>
Ok(input_expr_types[0].clone()),
+ }
+ }
+
+ /// Returns the [`Signature`] (accepted argument counts and types) of this
+ /// built-in window function.
+ pub fn signature(&self) -> Signature {
+ // note: the physical expression must accept the type returned by this
+ // function or the execution panics.
+ match self {
+ // Ranking / distribution functions take no arguments.
+ BuiltInWindowFunction::RowNumber
+ | BuiltInWindowFunction::Rank
+ | BuiltInWindowFunction::DenseRank
+ | BuiltInWindowFunction::PercentRank
+ | BuiltInWindowFunction::CumeDist => Signature::any(0,
Volatility::Immutable),
+ // value, optional offset, optional default
+ BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
+ Signature::one_of(
+ vec![
+ TypeSignature::Any(1),
+ TypeSignature::Any(2),
+ TypeSignature::Any(3),
+ ],
+ Volatility::Immutable,
+ )
+ }
+ BuiltInWindowFunction::FirstValue |
BuiltInWindowFunction::LastValue => {
+ Signature::any(1, Volatility::Immutable)
+ }
+ // A single integer argument: the number of groups to divide into.
+ BuiltInWindowFunction::Ntile => Signature::uniform(
+ 1,
+ vec![
+ DataType::UInt64,
+ DataType::UInt32,
+ DataType::UInt16,
+ DataType::UInt8,
+ DataType::Int64,
+ DataType::Int32,
+ DataType::Int16,
+ DataType::Int8,
+ ],
+ Volatility::Immutable,
+ ),
+ // value and n
+ BuiltInWindowFunction::NthValue => Signature::any(2,
Volatility::Immutable),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use strum::IntoEnumIterator;
+ // Test for BuiltInWindowFunction's Display and from_str() implementations.
+ // For each variant in BuiltInWindowFunction, it converts the variant to a
+ // string and then back to a variant. The test asserts that the original
+ // variant and the reconstructed variant are the same. This assertion is also
+ // necessary for function suggestion.
+ // See https://github.com/apache/arrow-datafusion/issues/8082
+ #[test]
+ fn test_display_and_from_str() {
+ for func_original in BuiltInWindowFunction::iter() {
+ let func_name = func_original.to_string();
+ let func_from_str =
BuiltInWindowFunction::from_str(&func_name).unwrap();
+ assert_eq!(func_from_str, func_original);
+ }
+ }
+}
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 0ec19bcadb..ebf4d3143c 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -19,13 +19,13 @@
use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
-use crate::udaf;
use crate::utils::{expr_to_columns, find_out_reference_exprs};
use crate::window_frame;
-use crate::window_function;
+
use crate::Operator;
use crate::{aggregate_function, ExprSchemable};
use crate::{built_in_function, BuiltinScalarFunction};
+use crate::{built_in_window_function, udaf};
use arrow::datatypes::DataType;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{internal_err, DFSchema, OwnedTableReference};
@@ -34,8 +34,11 @@ use std::collections::HashSet;
use std::fmt;
use std::fmt::{Display, Formatter, Write};
use std::hash::{BuildHasher, Hash, Hasher};
+use std::str::FromStr;
use std::sync::Arc;
+use crate::Signature;
+
/// `Expr` is a central struct of DataFusion's query API, and
/// represent logical expressions such as `A + 1`, or `CAST(c1 AS
/// int)`.
@@ -566,11 +569,64 @@ impl AggregateFunction {
}
}
+/// Defines which implementation of a window function DataFusion should call.
+///
+/// A window expression can be backed by a built-in or user-defined aggregate
+/// function, or by a built-in or user-defined window function.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum WindowFunctionDefinition {
+ /// A built-in aggregate function (e.g. `MAX`, `SUM`) evaluated over a window
+ AggregateFunction(aggregate_function::AggregateFunction),
+ /// A built-in window function (e.g. `ROW_NUMBER`, `LAG`)
+ BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction),
+ /// A user defined aggregate function evaluated over a window
+ AggregateUDF(Arc<crate::AggregateUDF>),
+ /// A user defined window function
+ WindowUDF(Arc<crate::WindowUDF>),
+}
+
+impl WindowFunctionDefinition {
+ /// Returns the datatype of the window function, given the types of its
+ /// input expressions, by delegating to the wrapped function's `return_type`.
+ ///
+ /// # Errors
+ ///
+ /// Propagates any error from the wrapped function (for built-in window
+ /// functions, e.g. a signature mismatch).
+ pub fn return_type(&self, input_expr_types: &[DataType]) ->
Result<DataType> {
+ match self {
+ WindowFunctionDefinition::AggregateFunction(fun) => {
+ fun.return_type(input_expr_types)
+ }
+ WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
+ fun.return_type(input_expr_types)
+ }
+ WindowFunctionDefinition::AggregateUDF(fun) => {
+ fun.return_type(input_expr_types)
+ }
+ WindowFunctionDefinition::WindowUDF(fun) =>
fun.return_type(input_expr_types),
+ }
+ }
+
+ /// Returns the [`Signature`] supported by the wrapped function.
+ ///
+ /// The UDF variants return a clone of the signature obtained from the UDF.
+ pub fn signature(&self) -> Signature {
+ match self {
+ WindowFunctionDefinition::AggregateFunction(fun) =>
fun.signature(),
+ WindowFunctionDefinition::BuiltInWindowFunction(fun) =>
fun.signature(),
+ WindowFunctionDefinition::AggregateUDF(fun) =>
fun.signature().clone(),
+ WindowFunctionDefinition::WindowUDF(fun) =>
fun.signature().clone(),
+ }
+ }
+}
+
+/// Formats the wrapped function.
+///
+/// Built-in variants render as their upper-case SQL name (e.g. `ROW_NUMBER`,
+/// `MAX`); the other variants delegate to the wrapped value's formatter.
+impl fmt::Display for WindowFunctionDefinition {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match self {
+ WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f),
+ WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f),
+ // NOTE(review): formatted with `Debug`, unlike the other variants;
+ // presumably because `AggregateUDF` has no `Display` impl — confirm.
+ WindowFunctionDefinition::AggregateUDF(fun) =>
std::fmt::Debug::fmt(fun, f),
+ WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f),
+ }
+ }
+}
+
/// Window function
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct WindowFunction {
/// Name of the function
- pub fun: window_function::WindowFunction,
+ pub fun: WindowFunctionDefinition,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
/// List of partition by expressions
@@ -584,7 +640,7 @@ pub struct WindowFunction {
impl WindowFunction {
/// Create a new Window expression
pub fn new(
- fun: window_function::WindowFunction,
+ fun: WindowFunctionDefinition,
args: Vec<Expr>,
partition_by: Vec<Expr>,
order_by: Vec<Expr>,
@@ -600,6 +656,50 @@ impl WindowFunction {
}
}
+/// Find DataFusion's built-in window function by name.
+///
+/// The lookup is case-insensitive. Built-in window functions (e.g. `lag`) are
+/// searched first, then built-in aggregate functions (e.g. `max`), so a name
+/// known to both resolves to the built-in window function. User-defined
+/// functions are not considered. Returns `None` if nothing matches.
+pub fn find_df_window_func(name: &str) -> Option<WindowFunctionDefinition> {
+ // Normalise case once for both lookups. `BuiltInWindowFunction::from_str`
+ // upper-cases internally; the lower-casing presumably suits
+ // `AggregateFunction::from_str` — confirm.
+ let name = name.to_lowercase();
+ // Code paths for window functions leveraging ordinary aggregators and
+ // built-in window functions are quite different, and the same function
+ // may have different implementations for these cases. If the sought
+ // function is not found among built-in window functions, we search for
+ // it among aggregate functions.
+ if let Ok(built_in_function) =
+
built_in_window_function::BuiltInWindowFunction::from_str(name.as_str())
+ {
+ Some(WindowFunctionDefinition::BuiltInWindowFunction(
+ built_in_function,
+ ))
+ } else if let Ok(aggregate) =
+ aggregate_function::AggregateFunction::from_str(name.as_str())
+ {
+ Some(WindowFunctionDefinition::AggregateFunction(aggregate))
+ } else {
+ None
+ }
+}
+
+// NOTE(review): the old `window_function` module is deleted by this change, so
+// these deprecated shims are now reachable only via `expr`, not at their former
+// path; confirm whether they are still worth keeping.
+
+/// Returns the datatype of the window function.
+///
+/// Deprecated free-function form of [`WindowFunctionDefinition::return_type`].
+#[deprecated(
+ since = "27.0.0",
+ note = "please use `WindowFunctionDefinition::return_type` instead"
+)]
+pub fn return_type(
+ fun: &WindowFunctionDefinition,
+ input_expr_types: &[DataType],
+) -> Result<DataType> {
+ fun.return_type(input_expr_types)
+}
+
+/// Returns the signatures supported by the function `fun`.
+///
+/// Deprecated free-function form of [`WindowFunctionDefinition::signature`].
+#[deprecated(
+ since = "27.0.0",
+ note = "please use `WindowFunctionDefinition::signature` instead"
+)]
+pub fn signature(fun: &WindowFunctionDefinition) -> Signature {
+ fun.signature()
+}
+
// Exists expression.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Exists {
@@ -1890,4 +1990,187 @@ mod test {
.is_volatile()
.expect_err("Shouldn't determine volatility of unresolved
function");
}
+
+ // Tests for `WindowFunctionDefinition` and `find_df_window_func`.
+ // NOTE(review): the enclosing `mod test` may already import `super::*`;
+ // if so, this second glob import is redundant — confirm.
+ use super::*;
+
+ // Return-type tests: each looks a function up by name and checks the
+ // DataType produced for representative argument types.
+ #[test]
+ fn test_count_return_type() -> Result<()> {
+ let fun = find_df_window_func("count").unwrap();
+ let observed = fun.return_type(&[DataType::Utf8])?;
+ assert_eq!(DataType::Int64, observed);
+
+ let observed = fun.return_type(&[DataType::UInt64])?;
+ assert_eq!(DataType::Int64, observed);
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_first_value_return_type() -> Result<()> {
+ let fun = find_df_window_func("first_value").unwrap();
+ let observed = fun.return_type(&[DataType::Utf8])?;
+ assert_eq!(DataType::Utf8, observed);
+
+ let observed = fun.return_type(&[DataType::UInt64])?;
+ assert_eq!(DataType::UInt64, observed);
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_last_value_return_type() -> Result<()> {
+ let fun = find_df_window_func("last_value").unwrap();
+ let observed = fun.return_type(&[DataType::Utf8])?;
+ assert_eq!(DataType::Utf8, observed);
+
+ let observed = fun.return_type(&[DataType::Float64])?;
+ assert_eq!(DataType::Float64, observed);
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_lead_return_type() -> Result<()> {
+ let fun = find_df_window_func("lead").unwrap();
+ let observed = fun.return_type(&[DataType::Utf8])?;
+ assert_eq!(DataType::Utf8, observed);
+
+ let observed = fun.return_type(&[DataType::Float64])?;
+ assert_eq!(DataType::Float64, observed);
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_lag_return_type() -> Result<()> {
+ let fun = find_df_window_func("lag").unwrap();
+ let observed = fun.return_type(&[DataType::Utf8])?;
+ assert_eq!(DataType::Utf8, observed);
+
+ let observed = fun.return_type(&[DataType::Float64])?;
+ assert_eq!(DataType::Float64, observed);
+
+ Ok(())
+ }
+
+ // nth_value takes (value, n) and returns the type of `value`.
+ #[test]
+ fn test_nth_value_return_type() -> Result<()> {
+ let fun = find_df_window_func("nth_value").unwrap();
+ let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?;
+ assert_eq!(DataType::Utf8, observed);
+
+ let observed = fun.return_type(&[DataType::Float64,
DataType::UInt64])?;
+ assert_eq!(DataType::Float64, observed);
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_percent_rank_return_type() -> Result<()> {
+ let fun = find_df_window_func("percent_rank").unwrap();
+ let observed = fun.return_type(&[])?;
+ assert_eq!(DataType::Float64, observed);
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_cume_dist_return_type() -> Result<()> {
+ let fun = find_df_window_func("cume_dist").unwrap();
+ let observed = fun.return_type(&[])?;
+ assert_eq!(DataType::Float64, observed);
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_ntile_return_type() -> Result<()> {
+ let fun = find_df_window_func("ntile").unwrap();
+ let observed = fun.return_type(&[DataType::Int16])?;
+ assert_eq!(DataType::UInt64, observed);
+
+ Ok(())
+ }
+
+ // Every name must resolve identically in lower and upper case, and its
+ // `Display` output must be the upper-case name.
+ #[test]
+ fn test_window_function_case_insensitive() -> Result<()> {
+ let names = vec![
+ "row_number",
+ "rank",
+ "dense_rank",
+ "percent_rank",
+ "cume_dist",
+ "ntile",
+ "lag",
+ "lead",
+ "first_value",
+ "last_value",
+ "nth_value",
+ "min",
+ "max",
+ "count",
+ "avg",
+ "sum",
+ ];
+ for name in names {
+ let fun = find_df_window_func(name).unwrap();
+ let fun2 =
find_df_window_func(name.to_uppercase().as_str()).unwrap();
+ assert_eq!(fun, fun2);
+ assert_eq!(fun.to_string(), name.to_uppercase());
+ }
+ Ok(())
+ }
+
+ // Aggregate names resolve to `AggregateFunction`, built-in window names to
+ // `BuiltInWindowFunction`, regardless of case; unknown names give `None`.
+ #[test]
+ fn test_find_df_window_function() {
+ assert_eq!(
+ find_df_window_func("max"),
+ Some(WindowFunctionDefinition::AggregateFunction(
+ aggregate_function::AggregateFunction::Max
+ ))
+ );
+ assert_eq!(
+ find_df_window_func("min"),
+ Some(WindowFunctionDefinition::AggregateFunction(
+ aggregate_function::AggregateFunction::Min
+ ))
+ );
+ assert_eq!(
+ find_df_window_func("avg"),
+ Some(WindowFunctionDefinition::AggregateFunction(
+ aggregate_function::AggregateFunction::Avg
+ ))
+ );
+ assert_eq!(
+ find_df_window_func("cume_dist"),
+ Some(WindowFunctionDefinition::BuiltInWindowFunction(
+ built_in_window_function::BuiltInWindowFunction::CumeDist
+ ))
+ );
+ assert_eq!(
+ find_df_window_func("first_value"),
+ Some(WindowFunctionDefinition::BuiltInWindowFunction(
+ built_in_window_function::BuiltInWindowFunction::FirstValue
+ ))
+ );
+ assert_eq!(
+ find_df_window_func("LAST_value"),
+ Some(WindowFunctionDefinition::BuiltInWindowFunction(
+ built_in_window_function::BuiltInWindowFunction::LastValue
+ ))
+ );
+ assert_eq!(
+ find_df_window_func("LAG"),
+ Some(WindowFunctionDefinition::BuiltInWindowFunction(
+ built_in_window_function::BuiltInWindowFunction::Lag
+ ))
+ );
+ assert_eq!(
+ find_df_window_func("LEAD"),
+ Some(WindowFunctionDefinition::BuiltInWindowFunction(
+ built_in_window_function::BuiltInWindowFunction::Lead
+ ))
+ );
+ assert_eq!(find_df_window_func("not_exist"), None)
+ }
}
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index bf8e9e2954..ab213a19a3 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -27,6 +27,7 @@
mod accumulator;
mod built_in_function;
+mod built_in_window_function;
mod columnar_value;
mod literal;
mod nullif;
@@ -53,16 +54,16 @@ pub mod tree_node;
pub mod type_coercion;
pub mod utils;
pub mod window_frame;
-pub mod window_function;
pub mod window_state;
pub use accumulator::Accumulator;
pub use aggregate_function::AggregateFunction;
pub use built_in_function::BuiltinScalarFunction;
+pub use built_in_window_function::BuiltInWindowFunction;
pub use columnar_value::ColumnarValue;
pub use expr::{
Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField,
GroupingSet,
- Like, ScalarFunctionDefinition, TryCast,
+ Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition,
};
pub use expr_fn::*;
pub use expr_schema::ExprSchemable;
@@ -83,7 +84,6 @@ pub use udaf::AggregateUDF;
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::WindowUDF;
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
-pub use window_function::{BuiltInWindowFunction, WindowFunction};
#[cfg(test)]
#[ctor::ctor]
diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs
index c233ee84b3..a97a68341f 100644
--- a/datafusion/expr/src/udwf.rs
+++ b/datafusion/expr/src/udwf.rs
@@ -107,7 +107,7 @@ impl WindowUDF {
order_by: Vec<Expr>,
window_frame: WindowFrame,
) -> Expr {
- let fun = crate::WindowFunction::WindowUDF(Arc::new(self.clone()));
+ let fun =
crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone()));
Expr::WindowFunction(crate::expr::WindowFunction {
fun,
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 09f4842c9e..e3ecdf154e 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -1234,7 +1234,7 @@ mod tests {
use super::*;
use crate::{
col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup,
AggregateFunction,
- WindowFrame, WindowFunction,
+ WindowFrame, WindowFunctionDefinition,
};
#[test]
@@ -1248,28 +1248,28 @@ mod tests {
#[test]
fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
let max1 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Max),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Max),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Min),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Sum),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![],
@@ -1291,28 +1291,28 @@ mod tests {
let created_at_desc =
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false,
true));
let max1 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Max),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(true),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Max),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Min),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(true),
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Sum),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
@@ -1343,7 +1343,7 @@ mod tests {
fn test_find_sort_exprs() -> Result<()> {
let exprs = &[
Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Max),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![
@@ -1353,7 +1353,7 @@ mod tests {
WindowFrame::new(true),
)),
Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Sum),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![
diff --git a/datafusion/expr/src/window_function.rs
b/datafusion/expr/src/window_function.rs
deleted file mode 100644
index 610f1ecaea..0000000000
--- a/datafusion/expr/src/window_function.rs
+++ /dev/null
@@ -1,483 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Window functions provide the ability to perform calculations across
-//! sets of rows that are related to the current query row.
-//!
-//! see also <https://www.postgresql.org/docs/current/functions-window.html>
-
-use crate::aggregate_function::AggregateFunction;
-use crate::type_coercion::functions::data_types;
-use crate::utils;
-use crate::{AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF};
-use arrow::datatypes::DataType;
-use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError,
Result};
-use std::sync::Arc;
-use std::{fmt, str::FromStr};
-use strum_macros::EnumIter;
-
-/// WindowFunction
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-pub enum WindowFunction {
- /// A built in aggregate function that leverages an aggregate function
- AggregateFunction(AggregateFunction),
- /// A a built-in window function
- BuiltInWindowFunction(BuiltInWindowFunction),
- /// A user defined aggregate function
- AggregateUDF(Arc<AggregateUDF>),
- /// A user defined aggregate function
- WindowUDF(Arc<WindowUDF>),
-}
-
-/// Find DataFusion's built-in window function by name.
-pub fn find_df_window_func(name: &str) -> Option<WindowFunction> {
- let name = name.to_lowercase();
- // Code paths for window functions leveraging ordinary aggregators and
- // built-in window functions are quite different, and the same function
- // may have different implementations for these cases. If the sought
- // function is not found among built-in window functions, we search for
- // it among aggregate functions.
- if let Ok(built_in_function) =
BuiltInWindowFunction::from_str(name.as_str()) {
- Some(WindowFunction::BuiltInWindowFunction(built_in_function))
- } else if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) {
- Some(WindowFunction::AggregateFunction(aggregate))
- } else {
- None
- }
-}
-
-impl fmt::Display for BuiltInWindowFunction {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- write!(f, "{}", self.name())
- }
-}
-
-impl fmt::Display for WindowFunction {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- match self {
- WindowFunction::AggregateFunction(fun) => fun.fmt(f),
- WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f),
- WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f),
- WindowFunction::WindowUDF(fun) => fun.fmt(f),
- }
- }
-}
-
-/// A [window function] built in to DataFusion
-///
-/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
-#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)]
-pub enum BuiltInWindowFunction {
- /// number of the current row within its partition, counting from 1
- RowNumber,
- /// rank of the current row with gaps; same as row_number of its first peer
- Rank,
- /// rank of the current row without gaps; this function counts peer groups
- DenseRank,
- /// relative rank of the current row: (rank - 1) / (total rows - 1)
- PercentRank,
- /// relative rank of the current row: (number of rows preceding or peer
with current row) / (total rows)
- CumeDist,
- /// integer ranging from 1 to the argument value, dividing the partition
as equally as possible
- Ntile,
- /// returns value evaluated at the row that is offset rows before the
current row within the partition;
- /// if there is no such row, instead return default (which must be of the
same type as value).
- /// Both offset and default are evaluated with respect to the current row.
- /// If omitted, offset defaults to 1 and default to null
- Lag,
- /// returns value evaluated at the row that is offset rows after the
current row within the partition;
- /// if there is no such row, instead return default (which must be of the
same type as value).
- /// Both offset and default are evaluated with respect to the current row.
- /// If omitted, offset defaults to 1 and default to null
- Lead,
- /// returns value evaluated at the row that is the first row of the window
frame
- FirstValue,
- /// returns value evaluated at the row that is the last row of the window
frame
- LastValue,
- /// returns value evaluated at the row that is the nth row of the window
frame (counting from 1); null if no such row
- NthValue,
-}
-
-impl BuiltInWindowFunction {
- fn name(&self) -> &str {
- use BuiltInWindowFunction::*;
- match self {
- RowNumber => "ROW_NUMBER",
- Rank => "RANK",
- DenseRank => "DENSE_RANK",
- PercentRank => "PERCENT_RANK",
- CumeDist => "CUME_DIST",
- Ntile => "NTILE",
- Lag => "LAG",
- Lead => "LEAD",
- FirstValue => "FIRST_VALUE",
- LastValue => "LAST_VALUE",
- NthValue => "NTH_VALUE",
- }
- }
-}
-
-impl FromStr for BuiltInWindowFunction {
- type Err = DataFusionError;
- fn from_str(name: &str) -> Result<BuiltInWindowFunction> {
- Ok(match name.to_uppercase().as_str() {
- "ROW_NUMBER" => BuiltInWindowFunction::RowNumber,
- "RANK" => BuiltInWindowFunction::Rank,
- "DENSE_RANK" => BuiltInWindowFunction::DenseRank,
- "PERCENT_RANK" => BuiltInWindowFunction::PercentRank,
- "CUME_DIST" => BuiltInWindowFunction::CumeDist,
- "NTILE" => BuiltInWindowFunction::Ntile,
- "LAG" => BuiltInWindowFunction::Lag,
- "LEAD" => BuiltInWindowFunction::Lead,
- "FIRST_VALUE" => BuiltInWindowFunction::FirstValue,
- "LAST_VALUE" => BuiltInWindowFunction::LastValue,
- "NTH_VALUE" => BuiltInWindowFunction::NthValue,
- _ => return plan_err!("There is no built-in window function named
{name}"),
- })
- }
-}
-
-/// Returns the datatype of the window function
-#[deprecated(
- since = "27.0.0",
- note = "please use `WindowFunction::return_type` instead"
-)]
-pub fn return_type(
- fun: &WindowFunction,
- input_expr_types: &[DataType],
-) -> Result<DataType> {
- fun.return_type(input_expr_types)
-}
-
-impl WindowFunction {
- /// Returns the datatype of the window function
- pub fn return_type(&self, input_expr_types: &[DataType]) ->
Result<DataType> {
- match self {
- WindowFunction::AggregateFunction(fun) =>
fun.return_type(input_expr_types),
- WindowFunction::BuiltInWindowFunction(fun) => {
- fun.return_type(input_expr_types)
- }
- WindowFunction::AggregateUDF(fun) =>
fun.return_type(input_expr_types),
- WindowFunction::WindowUDF(fun) =>
fun.return_type(input_expr_types),
- }
- }
-}
-
-/// Returns the datatype of the built-in window function
-impl BuiltInWindowFunction {
- pub fn return_type(&self, input_expr_types: &[DataType]) ->
Result<DataType> {
- // Note that this function *must* return the same type that the
respective physical expression returns
- // or the execution panics.
-
- // verify that this is a valid set of data types for this function
- data_types(input_expr_types, &self.signature())
- // original errors are all related to wrong function signature
- // aggregate them for better error message
- .map_err(|_| {
- plan_datafusion_err!(
- "{}",
- utils::generate_signature_error_msg(
- &format!("{self}"),
- self.signature(),
- input_expr_types,
- )
- )
- })?;
-
- match self {
- BuiltInWindowFunction::RowNumber
- | BuiltInWindowFunction::Rank
- | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64),
- BuiltInWindowFunction::PercentRank |
BuiltInWindowFunction::CumeDist => {
- Ok(DataType::Float64)
- }
- BuiltInWindowFunction::Ntile => Ok(DataType::UInt64),
- BuiltInWindowFunction::Lag
- | BuiltInWindowFunction::Lead
- | BuiltInWindowFunction::FirstValue
- | BuiltInWindowFunction::LastValue
- | BuiltInWindowFunction::NthValue =>
Ok(input_expr_types[0].clone()),
- }
- }
-}
-
-/// the signatures supported by the function `fun`.
-#[deprecated(
- since = "27.0.0",
- note = "please use `WindowFunction::signature` instead"
-)]
-pub fn signature(fun: &WindowFunction) -> Signature {
- fun.signature()
-}
-
-impl WindowFunction {
- /// the signatures supported by the function `fun`.
- pub fn signature(&self) -> Signature {
- match self {
- WindowFunction::AggregateFunction(fun) => fun.signature(),
- WindowFunction::BuiltInWindowFunction(fun) => fun.signature(),
- WindowFunction::AggregateUDF(fun) => fun.signature().clone(),
- WindowFunction::WindowUDF(fun) => fun.signature().clone(),
- }
- }
-}
-
-/// the signatures supported by the built-in window function `fun`.
-#[deprecated(
- since = "27.0.0",
- note = "please use `BuiltInWindowFunction::signature` instead"
-)]
-pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
- fun.signature()
-}
-
-impl BuiltInWindowFunction {
- /// the signatures supported by the built-in window function `fun`.
- pub fn signature(&self) -> Signature {
- // note: the physical expression must accept the type returned by this
function or the execution panics.
- match self {
- BuiltInWindowFunction::RowNumber
- | BuiltInWindowFunction::Rank
- | BuiltInWindowFunction::DenseRank
- | BuiltInWindowFunction::PercentRank
- | BuiltInWindowFunction::CumeDist => Signature::any(0,
Volatility::Immutable),
- BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
- Signature::one_of(
- vec![
- TypeSignature::Any(1),
- TypeSignature::Any(2),
- TypeSignature::Any(3),
- ],
- Volatility::Immutable,
- )
- }
- BuiltInWindowFunction::FirstValue |
BuiltInWindowFunction::LastValue => {
- Signature::any(1, Volatility::Immutable)
- }
- BuiltInWindowFunction::Ntile => Signature::uniform(
- 1,
- vec![
- DataType::UInt64,
- DataType::UInt32,
- DataType::UInt16,
- DataType::UInt8,
- DataType::Int64,
- DataType::Int32,
- DataType::Int16,
- DataType::Int8,
- ],
- Volatility::Immutable,
- ),
- BuiltInWindowFunction::NthValue => Signature::any(2,
Volatility::Immutable),
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use strum::IntoEnumIterator;
-
- #[test]
- fn test_count_return_type() -> Result<()> {
- let fun = find_df_window_func("count").unwrap();
- let observed = fun.return_type(&[DataType::Utf8])?;
- assert_eq!(DataType::Int64, observed);
-
- let observed = fun.return_type(&[DataType::UInt64])?;
- assert_eq!(DataType::Int64, observed);
-
- Ok(())
- }
-
- #[test]
- fn test_first_value_return_type() -> Result<()> {
- let fun = find_df_window_func("first_value").unwrap();
- let observed = fun.return_type(&[DataType::Utf8])?;
- assert_eq!(DataType::Utf8, observed);
-
- let observed = fun.return_type(&[DataType::UInt64])?;
- assert_eq!(DataType::UInt64, observed);
-
- Ok(())
- }
-
- #[test]
- fn test_last_value_return_type() -> Result<()> {
- let fun = find_df_window_func("last_value").unwrap();
- let observed = fun.return_type(&[DataType::Utf8])?;
- assert_eq!(DataType::Utf8, observed);
-
- let observed = fun.return_type(&[DataType::Float64])?;
- assert_eq!(DataType::Float64, observed);
-
- Ok(())
- }
-
- #[test]
- fn test_lead_return_type() -> Result<()> {
- let fun = find_df_window_func("lead").unwrap();
- let observed = fun.return_type(&[DataType::Utf8])?;
- assert_eq!(DataType::Utf8, observed);
-
- let observed = fun.return_type(&[DataType::Float64])?;
- assert_eq!(DataType::Float64, observed);
-
- Ok(())
- }
-
- #[test]
- fn test_lag_return_type() -> Result<()> {
- let fun = find_df_window_func("lag").unwrap();
- let observed = fun.return_type(&[DataType::Utf8])?;
- assert_eq!(DataType::Utf8, observed);
-
- let observed = fun.return_type(&[DataType::Float64])?;
- assert_eq!(DataType::Float64, observed);
-
- Ok(())
- }
-
- #[test]
- fn test_nth_value_return_type() -> Result<()> {
- let fun = find_df_window_func("nth_value").unwrap();
- let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?;
- assert_eq!(DataType::Utf8, observed);
-
- let observed = fun.return_type(&[DataType::Float64,
DataType::UInt64])?;
- assert_eq!(DataType::Float64, observed);
-
- Ok(())
- }
-
- #[test]
- fn test_percent_rank_return_type() -> Result<()> {
- let fun = find_df_window_func("percent_rank").unwrap();
- let observed = fun.return_type(&[])?;
- assert_eq!(DataType::Float64, observed);
-
- Ok(())
- }
-
- #[test]
- fn test_cume_dist_return_type() -> Result<()> {
- let fun = find_df_window_func("cume_dist").unwrap();
- let observed = fun.return_type(&[])?;
- assert_eq!(DataType::Float64, observed);
-
- Ok(())
- }
-
- #[test]
- fn test_ntile_return_type() -> Result<()> {
- let fun = find_df_window_func("ntile").unwrap();
- let observed = fun.return_type(&[DataType::Int16])?;
- assert_eq!(DataType::UInt64, observed);
-
- Ok(())
- }
-
- #[test]
- fn test_window_function_case_insensitive() -> Result<()> {
- let names = vec![
- "row_number",
- "rank",
- "dense_rank",
- "percent_rank",
- "cume_dist",
- "ntile",
- "lag",
- "lead",
- "first_value",
- "last_value",
- "nth_value",
- "min",
- "max",
- "count",
- "avg",
- "sum",
- ];
- for name in names {
- let fun = find_df_window_func(name).unwrap();
- let fun2 =
find_df_window_func(name.to_uppercase().as_str()).unwrap();
- assert_eq!(fun, fun2);
- assert_eq!(fun.to_string(), name.to_uppercase());
- }
- Ok(())
- }
-
- #[test]
- fn test_find_df_window_function() {
- assert_eq!(
- find_df_window_func("max"),
- Some(WindowFunction::AggregateFunction(AggregateFunction::Max))
- );
- assert_eq!(
- find_df_window_func("min"),
- Some(WindowFunction::AggregateFunction(AggregateFunction::Min))
- );
- assert_eq!(
- find_df_window_func("avg"),
- Some(WindowFunction::AggregateFunction(AggregateFunction::Avg))
- );
- assert_eq!(
- find_df_window_func("cume_dist"),
- Some(WindowFunction::BuiltInWindowFunction(
- BuiltInWindowFunction::CumeDist
- ))
- );
- assert_eq!(
- find_df_window_func("first_value"),
- Some(WindowFunction::BuiltInWindowFunction(
- BuiltInWindowFunction::FirstValue
- ))
- );
- assert_eq!(
- find_df_window_func("LAST_value"),
- Some(WindowFunction::BuiltInWindowFunction(
- BuiltInWindowFunction::LastValue
- ))
- );
- assert_eq!(
- find_df_window_func("LAG"),
- Some(WindowFunction::BuiltInWindowFunction(
- BuiltInWindowFunction::Lag
- ))
- );
- assert_eq!(
- find_df_window_func("LEAD"),
- Some(WindowFunction::BuiltInWindowFunction(
- BuiltInWindowFunction::Lead
- ))
- );
- assert_eq!(find_df_window_func("not_exist"), None)
- }
-
- #[test]
- // Test for BuiltInWindowFunction's Display and from_str() implementations.
- // For each variant in BuiltInWindowFunction, it converts the variant to a
string
- // and then back to a variant. The test asserts that the original variant
and
- // the reconstructed variant are the same. This assertion is also
necessary for
- // function suggestion. See
https://github.com/apache/arrow-datafusion/issues/8082
- fn test_display_and_from_str() {
- for func_original in BuiltInWindowFunction::iter() {
- let func_name = func_original.to_string();
- let func_from_str =
BuiltInWindowFunction::from_str(&func_name).unwrap();
- assert_eq!(func_from_str, func_original);
- }
- }
-}
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index fd84bb8016..953716713e 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -24,7 +24,7 @@ use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::Expr::ScalarSubquery;
use datafusion_expr::{
- aggregate_function, expr, lit, window_function, Aggregate, Expr, Filter,
LogicalPlan,
+ aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan,
LogicalPlanBuilder, Projection, Sort, Subquery,
};
use std::sync::Arc;
@@ -121,7 +121,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
let new_expr = match old_expr.clone() {
Expr::WindowFunction(expr::WindowFunction {
fun:
- window_function::WindowFunction::AggregateFunction(
+ expr::WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Count,
),
args,
@@ -131,7 +131,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
}) if args.len() == 1 => match args[0] {
Expr::Wildcard { qualifier: None } => {
Expr::WindowFunction(expr::WindowFunction {
- fun:
window_function::WindowFunction::AggregateFunction(
+ fun: expr::WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Count,
),
args: vec![lit(COUNT_STAR_EXPANSION)],
@@ -229,7 +229,7 @@ mod tests {
use datafusion_expr::{
col, count, exists, expr, in_subquery, lit,
logical_plan::LogicalPlanBuilder,
max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr,
- WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction,
+ WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
@@ -342,7 +342,7 @@ mod tests {
let plan = LogicalPlanBuilder::from(table_scan)
.window(vec![Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Count),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index b6298f5b55..4d54dad996 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -45,9 +45,9 @@ use datafusion_expr::type_coercion::{is_datetime,
is_utf8_or_large_utf8};
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
- type_coercion, window_function, AggregateFunction, BuiltinScalarFunction,
Expr,
- ExprSchemable, LogicalPlan, Operator, Projection, ScalarFunctionDefinition,
- Signature, WindowFrame, WindowFrameBound, WindowFrameUnits,
+ type_coercion, AggregateFunction, BuiltinScalarFunction, Expr,
ExprSchemable,
+ LogicalPlan, Operator, Projection, ScalarFunctionDefinition, Signature,
WindowFrame,
+ WindowFrameBound, WindowFrameUnits,
};
use crate::analyzer::AnalyzerRule;
@@ -390,7 +390,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
coerce_window_frame(window_frame, &self.schema,
&order_by)?;
let args = match &fun {
- window_function::WindowFunction::AggregateFunction(fun) =>
{
+ expr::WindowFunctionDefinition::AggregateFunction(fun) => {
coerce_agg_exprs_for_signature(
fun,
&args,
diff --git a/datafusion/optimizer/src/push_down_projection.rs
b/datafusion/optimizer/src/push_down_projection.rs
index 10cc1879ae..4ee4f7e417 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -37,7 +37,7 @@ mod tests {
};
use datafusion_expr::{
col, count, lit, max, min, AggregateFunction, Expr, LogicalPlan,
Projection,
- WindowFrame, WindowFunction,
+ WindowFrame, WindowFunctionDefinition,
};
#[test]
@@ -582,7 +582,7 @@ mod tests {
let table_scan = test_table_scan()?;
let max1 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Max),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("test.a")],
vec![col("test.b")],
vec![],
@@ -590,7 +590,7 @@ mod tests {
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Max),
+
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("test.b")],
vec![],
vec![],
diff --git a/datafusion/physical-plan/src/windows/mod.rs
b/datafusion/physical-plan/src/windows/mod.rs
index 3187e6b0fb..fec168fabf 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -34,8 +34,8 @@ use arrow::datatypes::Schema;
use arrow_schema::{DataType, Field, SchemaRef};
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::{
- window_function::{BuiltInWindowFunction, WindowFunction},
- PartitionEvaluator, WindowFrame, WindowUDF,
+ BuiltInWindowFunction, PartitionEvaluator, WindowFrame,
WindowFunctionDefinition,
+ WindowUDF,
};
use datafusion_physical_expr::equivalence::collapse_lex_req;
use datafusion_physical_expr::{
@@ -56,7 +56,7 @@ pub use datafusion_physical_expr::window::{
/// Create a physical expression for window function
pub fn create_window_expr(
- fun: &WindowFunction,
+ fun: &WindowFunctionDefinition,
name: String,
args: &[Arc<dyn PhysicalExpr>],
partition_by: &[Arc<dyn PhysicalExpr>],
@@ -65,7 +65,7 @@ pub fn create_window_expr(
input_schema: &Schema,
) -> Result<Arc<dyn WindowExpr>> {
Ok(match fun {
- WindowFunction::AggregateFunction(fun) => {
+ WindowFunctionDefinition::AggregateFunction(fun) => {
let aggregate = aggregates::create_aggregate_expr(
fun,
false,
@@ -81,13 +81,15 @@ pub fn create_window_expr(
aggregate,
)
}
- WindowFunction::BuiltInWindowFunction(fun) =>
Arc::new(BuiltInWindowExpr::new(
- create_built_in_window_expr(fun, args, input_schema, name)?,
- partition_by,
- order_by,
- window_frame,
- )),
- WindowFunction::AggregateUDF(fun) => {
+ WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
+ Arc::new(BuiltInWindowExpr::new(
+ create_built_in_window_expr(fun, args, input_schema, name)?,
+ partition_by,
+ order_by,
+ window_frame,
+ ))
+ }
+ WindowFunctionDefinition::AggregateUDF(fun) => {
let aggregate =
udaf::create_aggregate_expr(fun.as_ref(), args, input_schema,
name)?;
window_expr_from_aggregate_expr(
@@ -97,7 +99,7 @@ pub fn create_window_expr(
aggregate,
)
}
- WindowFunction::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new(
+ WindowFunctionDefinition::WindowUDF(fun) =>
Arc::new(BuiltInWindowExpr::new(
create_udwf_window_expr(fun, args, input_schema, name)?,
partition_by,
order_by,
@@ -647,7 +649,7 @@ mod tests {
let refs = blocking_exec.refs();
let window_agg_exec = Arc::new(WindowAggExec::try_new(
vec![create_window_expr(
- &WindowFunction::AggregateFunction(AggregateFunction::Count),
+
&WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
"count".to_owned(),
&[col("a", &schema)?],
&[],
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index c582e92dc1..36c5b44f00 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -1112,7 +1112,7 @@ pub fn parse_expr(
let aggr_function = parse_i32_to_aggregate_function(i)?;
Ok(Expr::WindowFunction(WindowFunction::new(
-
datafusion_expr::window_function::WindowFunction::AggregateFunction(
+
datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction(
aggr_function,
),
vec![parse_required_expr(expr.expr.as_deref(),
registry, "expr")?],
@@ -1131,7 +1131,7 @@ pub fn parse_expr(
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
-
datafusion_expr::window_function::WindowFunction::BuiltInWindowFunction(
+
datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction(
built_in_function,
),
args,
@@ -1146,7 +1146,7 @@ pub fn parse_expr(
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
-
datafusion_expr::window_function::WindowFunction::AggregateUDF(
+
datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF(
udaf_function,
),
args,
@@ -1161,7 +1161,7 @@ pub fn parse_expr(
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
Ok(Expr::WindowFunction(WindowFunction::new(
-
datafusion_expr::window_function::WindowFunction::WindowUDF(
+
datafusion_expr::expr::WindowFunctionDefinition::WindowUDF(
udwf_function,
),
args,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index b9987ff6c7..a162b2389c 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -51,7 +51,7 @@ use datafusion_expr::expr::{
use datafusion_expr::{
logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction,
BuiltInWindowFunction, BuiltinScalarFunction, Expr, JoinConstraint,
JoinType,
- TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction,
+ TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
#[derive(Debug)]
@@ -605,22 +605,22 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
ref window_frame,
}) => {
let window_function = match fun {
- WindowFunction::AggregateFunction(fun) => {
+ WindowFunctionDefinition::AggregateFunction(fun) => {
protobuf::window_expr_node::WindowFunction::AggrFunction(
protobuf::AggregateFunction::from(fun).into(),
)
}
- WindowFunction::BuiltInWindowFunction(fun) => {
+ WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
protobuf::window_expr_node::WindowFunction::BuiltInFunction(
protobuf::BuiltInWindowFunction::from(fun).into(),
)
}
- WindowFunction::AggregateUDF(aggr_udf) => {
+ WindowFunctionDefinition::AggregateUDF(aggr_udf) => {
protobuf::window_expr_node::WindowFunction::Udaf(
aggr_udf.name().to_string(),
)
}
- WindowFunction::WindowUDF(window_udf) => {
+ WindowFunctionDefinition::WindowUDF(window_udf) => {
protobuf::window_expr_node::WindowFunction::Udwf(
window_udf.name().to_string(),
)
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs
b/datafusion/proto/src/physical_plan/from_proto.rs
index 8ad6d679df..23ab813ca7 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -31,7 +31,7 @@ use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig};
use datafusion::execution::context::ExecutionProps;
use datafusion::execution::FunctionRegistry;
-use datafusion::logical_expr::window_function::WindowFunction;
+use datafusion::logical_expr::WindowFunctionDefinition;
use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr};
use datafusion::physical_plan::expressions::{
in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr,
IsNullExpr, LikeExpr,
@@ -414,7 +414,9 @@ fn parse_required_physical_expr(
})
}
-impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for
WindowFunction {
+impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction>
+ for WindowFunctionDefinition
+{
type Error = DataFusionError;
fn try_from(
@@ -428,7 +430,7 @@ impl
TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun
))
})?;
- Ok(WindowFunction::AggregateFunction(f.into()))
+ Ok(WindowFunctionDefinition::AggregateFunction(f.into()))
}
protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => {
let f =
protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| {
@@ -437,7 +439,7 @@ impl
TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun
))
})?;
- Ok(WindowFunction::BuiltInWindowFunction(f.into()))
+ Ok(WindowFunctionDefinition::BuiltInWindowFunction(f.into()))
}
}
}
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 2d7d85abda..dea99f91e3 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -53,7 +53,7 @@ use datafusion_expr::{
col, create_udaf, lit, Accumulator, AggregateFunction,
BuiltinScalarFunction::{Sqrt, Substr},
Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast,
Volatility,
- WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, WindowUDF,
+ WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
WindowUDF,
};
use datafusion_proto::bytes::{
logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec,
@@ -1663,8 +1663,8 @@ fn roundtrip_window() {
// 1. without window_frame
let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::BuiltInWindowFunction(
- datafusion_expr::window_function::BuiltInWindowFunction::Rank,
+ WindowFunctionDefinition::BuiltInWindowFunction(
+ datafusion_expr::BuiltInWindowFunction::Rank,
),
vec![],
vec![col("col1")],
@@ -1674,8 +1674,8 @@ fn roundtrip_window() {
// 2. with default window_frame
let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::BuiltInWindowFunction(
- datafusion_expr::window_function::BuiltInWindowFunction::Rank,
+ WindowFunctionDefinition::BuiltInWindowFunction(
+ datafusion_expr::BuiltInWindowFunction::Rank,
),
vec![],
vec![col("col1")],
@@ -1691,8 +1691,8 @@ fn roundtrip_window() {
};
let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::BuiltInWindowFunction(
- datafusion_expr::window_function::BuiltInWindowFunction::Rank,
+ WindowFunctionDefinition::BuiltInWindowFunction(
+ datafusion_expr::BuiltInWindowFunction::Rank,
),
vec![],
vec![col("col1")],
@@ -1708,7 +1708,7 @@ fn roundtrip_window() {
};
let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(AggregateFunction::Max),
+ WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("col1")],
vec![col("col1")],
vec![col("col2")],
@@ -1759,7 +1759,7 @@ fn roundtrip_window() {
);
let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())),
+ WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())),
vec![col("col1")],
vec![col("col1")],
vec![col("col2")],
@@ -1808,7 +1808,7 @@ fn roundtrip_window() {
);
let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())),
+
WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())),
vec![col("col1")],
vec![col("col1")],
vec![col("col2")],
diff --git a/datafusion/sql/src/expr/function.rs
b/datafusion/sql/src/expr/function.rs
index 3934d6701c..395f10b6f7 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -23,8 +23,8 @@ use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::function::suggest_valid_function;
use datafusion_expr::window_frame::{check_window_frame,
regularize_window_order_by};
use datafusion_expr::{
- expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr,
WindowFrame,
- WindowFunction,
+ expr, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame,
+ WindowFunctionDefinition,
};
use sqlparser::ast::{
Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr,
WindowType,
@@ -121,12 +121,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
if let Ok(fun) = self.find_window_func(&name) {
let expr = match fun {
- WindowFunction::AggregateFunction(aggregate_fun) => {
+ WindowFunctionDefinition::AggregateFunction(aggregate_fun)
=> {
let args =
self.function_args_to_expr(args, schema,
planner_context)?;
Expr::WindowFunction(expr::WindowFunction::new(
- WindowFunction::AggregateFunction(aggregate_fun),
+
WindowFunctionDefinition::AggregateFunction(aggregate_fun),
args,
partition_by,
order_by,
@@ -191,19 +191,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args)))
}
- pub(super) fn find_window_func(&self, name: &str) ->
Result<WindowFunction> {
- window_function::find_df_window_func(name)
+ pub(super) fn find_window_func(
+ &self,
+ name: &str,
+ ) -> Result<WindowFunctionDefinition> {
+ expr::find_df_window_func(name)
// next check user defined aggregates
.or_else(|| {
self.context_provider
.get_aggregate_meta(name)
- .map(WindowFunction::AggregateUDF)
+ .map(WindowFunctionDefinition::AggregateUDF)
})
// next check user defined window functions
.or_else(|| {
self.context_provider
.get_window_meta(name)
- .map(WindowFunction::WindowUDF)
+ .map(WindowFunctionDefinition::WindowUDF)
})
.ok_or_else(|| {
plan_datafusion_err!("There is no window function named
{name}")
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index 9931dd15ae..a4ec3e7722 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -23,8 +23,8 @@ use datafusion::common::{
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
- aggregate_function, window_function::find_df_window_func, BinaryExpr,
- BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
+ aggregate_function, expr::find_df_window_func, BinaryExpr,
BuiltinScalarFunction,
+ Case, Expr, LogicalPlan, Operator,
};
use datafusion::logical_expr::{
expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,