This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 888504a8da Introduce Sum UDAF (#10651)
888504a8da is described below
commit 888504a8da6d20f9caf3ecb6cd1a6b7d1956e23e
Author: Jay Zhan <[email protected]>
AuthorDate: Mon Jun 3 19:43:30 2024 +0800
Introduce Sum UDAF (#10651)
* move accumulate
Signed-off-by: jayzhan211 <[email protected]>
* move prim_op
Signed-off-by: jayzhan211 <[email protected]>
* move test to slt
Signed-off-by: jayzhan211 <[email protected]>
* remove sum distinct
Signed-off-by: jayzhan211 <[email protected]>
* move sum aggregate
Signed-off-by: jayzhan211 <[email protected]>
* fix args
Signed-off-by: jayzhan211 <[email protected]>
* add sum
Signed-off-by: jayzhan211 <[email protected]>
* merge fix
Signed-off-by: jayzhan211 <[email protected]>
* fix sum sig
Signed-off-by: jayzhan211 <[email protected]>
* todo: wait ahash merge
Signed-off-by: jayzhan211 <[email protected]>
* rebase
Signed-off-by: jayzhan211 <[email protected]>
* disable ordering req by default
Signed-off-by: jayzhan211 <[email protected]>
* check arg count
Signed-off-by: jayzhan211 <[email protected]>
* rm old workflow
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* fix failed test
Signed-off-by: jayzhan211 <[email protected]>
* doc and fmt
Signed-off-by: jayzhan211 <[email protected]>
* check udaf first
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* fix ci
Signed-off-by: jayzhan211 <[email protected]>
* fix ci
Signed-off-by: jayzhan211 <[email protected]>
* fix ci
Signed-off-by: jayzhan211 <[email protected]>
* fix err msg AGAIN
Signed-off-by: jayzhan211 <[email protected]>
* rm sum in builtin test which covered in sql
Signed-off-by: jayzhan211 <[email protected]>
* proto for window with udaf
Signed-off-by: jayzhan211 <[email protected]>
* fix slt
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* fix err msg
Signed-off-by: jayzhan211 <[email protected]>
* fix exprfn
Signed-off-by: jayzhan211 <[email protected]>
* fix ci
Signed-off-by: jayzhan211 <[email protected]>
* fix ci
Signed-off-by: jayzhan211 <[email protected]>
* rename first/last to lowercase
Signed-off-by: jayzhan211 <[email protected]>
* skip sum
Signed-off-by: jayzhan211 <[email protected]>
* fix firstvalue
Signed-off-by: jayzhan211 <[email protected]>
* clippy
Signed-off-by: jayzhan211 <[email protected]>
* add doc
Signed-off-by: jayzhan211 <[email protected]>
* rm has_ordering_req
Signed-off-by: jayzhan211 <[email protected]>
* default hard req
Signed-off-by: jayzhan211 <[email protected]>
* insensitive for sum
Signed-off-by: jayzhan211 <[email protected]>
* cleanup duplicate code
Signed-off-by: jayzhan211 <[email protected]>
* Re-introduce check
---------
Signed-off-by: jayzhan211 <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
---
datafusion-cli/Cargo.lock | 1 +
datafusion-examples/examples/advanced_udaf.rs | 5 +-
.../examples/simplify_udaf_expression.rs | 6 +-
datafusion/core/src/dataframe/mod.rs | 8 +-
datafusion/core/src/physical_planner.rs | 5 +-
datafusion/core/src/prelude.rs | 1 -
datafusion/core/tests/dataframe/mod.rs | 6 +-
datafusion/core/tests/fuzz_cases/window_fuzz.rs | 12 +-
.../tests/user_defined/user_defined_aggregates.rs | 7 +-
.../user_defined/user_defined_scalar_functions.rs | 4 +-
datafusion/expr/src/built_in_window_function.rs | 4 +-
datafusion/expr/src/expr.rs | 18 +-
datafusion/expr/src/expr_fn.rs | 2 +
datafusion/expr/src/expr_schema.rs | 39 +-
datafusion/expr/src/function.rs | 6 +-
datafusion/expr/src/lib.rs | 2 +-
datafusion/expr/src/udaf.rs | 82 +++-
datafusion/functions-aggregate/Cargo.toml | 1 +
datafusion/functions-aggregate/src/first_last.rs | 7 +-
datafusion/functions-aggregate/src/lib.rs | 8 +
datafusion/functions-aggregate/src/sum.rs | 436 +++++++++++++++++++++
datafusion/optimizer/src/analyzer/type_coercion.rs | 18 +-
.../src/simplify_expressions/expr_simplifier.rs | 5 +-
.../optimizer/src/single_distinct_to_groupby.rs | 5 +-
.../physical-expr-common/src/aggregate/mod.rs | 120 ++++--
.../physical-expr-common/src/aggregate/utils.rs | 3 +-
datafusion/physical-expr/src/aggregate/build_in.rs | 102 +----
datafusion/physical-plan/src/windows/mod.rs | 46 ++-
datafusion/proto/proto/datafusion.proto | 2 +-
datafusion/proto/src/generated/pbjson.rs | 13 +
datafusion/proto/src/generated/prost.rs | 5 +-
datafusion/proto/src/physical_plan/from_proto.rs | 62 ++-
datafusion/proto/src/physical_plan/to_proto.rs | 32 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 4 +-
.../proto/tests/cases/roundtrip_physical_plan.rs | 21 +-
datafusion/sql/src/expr/function.rs | 34 +-
datafusion/sqllogictest/test_files/aggregate.slt | 12 +-
datafusion/sqllogictest/test_files/order.slt | 2 +-
.../sqllogictest/test_files/sort_merge_join.slt | 1 +
datafusion/sqllogictest/test_files/unnest.slt | 4 +-
datafusion/sqllogictest/test_files/window.slt | 36 +-
41 files changed, 888 insertions(+), 299 deletions(-)
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 6a1ba8aba0..3040586501 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1287,6 +1287,7 @@ dependencies = [
name = "datafusion-functions-aggregate"
version = "38.0.0"
dependencies = [
+ "ahash",
"arrow",
"arrow-schema",
"datafusion-common",
diff --git a/datafusion-examples/examples/advanced_udaf.rs
b/datafusion-examples/examples/advanced_udaf.rs
index cf28447221..2c672a18a7 100644
--- a/datafusion-examples/examples/advanced_udaf.rs
+++ b/datafusion-examples/examples/advanced_udaf.rs
@@ -105,7 +105,10 @@ impl AggregateUDFImpl for GeoMeanUdaf {
true
}
- fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+ fn create_groups_accumulator(
+ &self,
+ _args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(GeometricMeanGroupsAccumulator::new()))
}
}
diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs
b/datafusion-examples/examples/simplify_udaf_expression.rs
index 08b6bcab01..d2c8c6a86c 100644
--- a/datafusion-examples/examples/simplify_udaf_expression.rs
+++ b/datafusion-examples/examples/simplify_udaf_expression.rs
@@ -78,9 +78,13 @@ impl AggregateUDFImpl for BetterAvgUdaf {
true
}
- fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+ fn create_groups_accumulator(
+ &self,
+ _args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
unimplemented!("should not get here");
}
+
// we override method, to return new expression which would substitute
// user defined function call
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
diff --git a/datafusion/core/src/dataframe/mod.rs
b/datafusion/core/src/dataframe/mod.rs
index aac506d48b..5b1aef5d2b 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -53,8 +53,9 @@ use datafusion_expr::{
avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown, UNNAMED_TABLE,
};
-use datafusion_expr::{case, is_null, sum};
+use datafusion_expr::{case, is_null};
use datafusion_functions_aggregate::expr_fn::median;
+use datafusion_functions_aggregate::expr_fn::sum;
use async_trait::async_trait;
@@ -1593,9 +1594,8 @@ mod tests {
use datafusion_common::{Constraint, Constraints};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::{
- array_agg, cast, count_distinct, create_udf, expr, lit, sum,
- BuiltInWindowFunction, ScalarFunctionImplementation, Volatility,
WindowFrame,
- WindowFunctionDefinition,
+ array_agg, cast, count_distinct, create_udf, expr, lit,
BuiltInWindowFunction,
+ ScalarFunctionImplementation, Volatility, WindowFrame,
WindowFunctionDefinition,
};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};
diff --git a/datafusion/core/src/physical_planner.rs
b/datafusion/core/src/physical_planner.rs
index 5e2e546a86..3bc8983532 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -2257,9 +2257,8 @@ mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, TableReference};
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::TaskContext;
- use datafusion_expr::{
- col, lit, sum, LogicalPlanBuilder, UserDefinedLogicalNodeCore,
- };
+ use datafusion_expr::{col, lit, LogicalPlanBuilder,
UserDefinedLogicalNodeCore};
+ use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_physical_expr::EquivalenceProperties;
fn make_session_state() -> SessionState {
diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs
index 0d8d06f49b..d82a5a2cc1 100644
--- a/datafusion/core/src/prelude.rs
+++ b/datafusion/core/src/prelude.rs
@@ -39,7 +39,6 @@ pub use datafusion_expr::{
Expr,
};
pub use datafusion_functions::expr_fn::*;
-pub use datafusion_functions_aggregate::expr_fn::*;
#[cfg(feature = "array_expressions")]
pub use datafusion_functions_array::expr_fn::*;
diff --git a/datafusion/core/tests/dataframe/mod.rs
b/datafusion/core/tests/dataframe/mod.rs
index 60e60bb1e3..befd98d043 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -52,10 +52,10 @@ use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max,
out_ref_col,
- placeholder, scalar_subquery, sum, when, wildcard, AggregateFunction, Expr,
- ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits,
- WindowFunctionDefinition,
+ placeholder, scalar_subquery, when, wildcard, AggregateFunction, Expr,
ExprSchemable,
+ WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
+use datafusion_functions_aggregate::expr_fn::sum;
#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index fe0c408dc1..b85f6376c3 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -33,10 +33,12 @@ use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::{Result, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::type_coercion::aggregates::coerce_types;
+use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
use datafusion_expr::{
AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
+use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use test_utils::add_empty_batches;
@@ -341,7 +343,7 @@ fn get_random_function(
window_fn_map.insert(
"sum",
(
-
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
+ WindowFunctionDefinition::AggregateUDF(sum_udaf()),
vec![arg.clone()],
),
);
@@ -468,6 +470,14 @@ fn get_random_function(
let coerced = coerce_types(f, &[dt], &sig).unwrap();
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
}
+ } else if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn {
+ if !args.is_empty() {
+ // Do type coercion first argument
+ let a = args[0].clone();
+ let dt = a.data_type(schema.as_ref()).unwrap();
+ let coerced = data_types_with_aggregate_udf(&[dt], udf).unwrap();
+ args[0] = cast(a, schema, coerced[0].clone()).unwrap();
+ }
}
(window_fn.clone(), args, fn_name.to_string())
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index d199f04ba7..071db5adf0 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -142,7 +142,7 @@ async fn
test_udaf_as_window_with_frame_without_retract_batch() {
let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1
PRECEDING AND 1 FOLLOWING) as time_sum from t";
// Note if this query ever does start working
let err = execute(&ctx, sql).await.unwrap_err();
- assert_contains!(err.to_string(), "This feature is not implemented:
Aggregate can not be used as a sliding accumulator because `retract_batch` is
not implemented: AggregateUDF { inner: AggregateUDF { name: \"time_sum\",
signature: Signature { type_signature: Exact([Timestamp(Nanosecond, None)]),
volatility: Immutable }, fun: \"<FUNC>\" } }(t.time) ORDER BY [t.time ASC NULLS
LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING");
+ assert_contains!(err.to_string(), "This feature is not implemented:
Aggregate can not be used as a sliding accumulator because `retract_batch` is
not implemented: time_sum(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN
1 PRECEDING AND 1 FOLLOWING");
}
/// Basic query for with a udaf returning a structure
@@ -729,7 +729,10 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
true
}
- fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+ fn create_groups_accumulator(
+ &self,
+ _args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(self.clone()))
}
}
diff --git
a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index df41cab7bf..2d98b7f80f 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -378,8 +378,8 @@ async fn udaf_as_window_func() -> Result<()> {
context.register_udaf(my_acc);
let sql = "SELECT a, MY_ACC(b) OVER(PARTITION BY a) FROM my_table";
- let expected = r#"Projection: my_table.a, AggregateUDF { inner:
AggregateUDF { name: "my_acc", signature: Signature { type_signature:
Exact([Int32]), volatility: Immutable }, fun: "<FUNC>" } }(my_table.b)
PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED
FOLLOWING
- WindowAggr: windowExpr=[[AggregateUDF { inner: AggregateUDF { name:
"my_acc", signature: Signature { type_signature: Exact([Int32]), volatility:
Immutable }, fun: "<FUNC>" } }(my_table.b) PARTITION BY [my_table.a] ROWS
BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
+ let expected = r#"Projection: my_table.a, my_acc(my_table.b) PARTITION BY
[my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ WindowAggr: windowExpr=[[my_acc(my_table.b) PARTITION BY [my_table.a] ROWS
BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
TableScan: my_table"#;
let dataframe = context.sql(sql).await.unwrap();
diff --git a/datafusion/expr/src/built_in_window_function.rs
b/datafusion/expr/src/built_in_window_function.rs
index 18a888ae8b..3885d70049 100644
--- a/datafusion/expr/src/built_in_window_function.rs
+++ b/datafusion/expr/src/built_in_window_function.rs
@@ -82,8 +82,8 @@ impl BuiltInWindowFunction {
Ntile => "NTILE",
Lag => "LAG",
Lead => "LEAD",
- FirstValue => "FIRST_VALUE",
- LastValue => "LAST_VALUE",
+ FirstValue => "first_value",
+ LastValue => "last_value",
NthValue => "NTH_VALUE",
}
}
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 71cf3adddf..14c64ef8f8 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -754,10 +754,14 @@ impl WindowFunctionDefinition {
impl fmt::Display for WindowFunctionDefinition {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
- WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f),
- WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f),
- WindowFunctionDefinition::AggregateUDF(fun) =>
std::fmt::Debug::fmt(fun, f),
- WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f),
+ WindowFunctionDefinition::AggregateFunction(fun) => {
+ std::fmt::Display::fmt(fun, f)
+ }
+ WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
+ std::fmt::Display::fmt(fun, f)
+ }
+ WindowFunctionDefinition::AggregateUDF(fun) =>
std::fmt::Display::fmt(fun, f),
+ WindowFunctionDefinition::WindowUDF(fun) =>
std::fmt::Display::fmt(fun, f),
}
}
}
@@ -2263,7 +2267,11 @@ mod test {
let fun = find_df_window_func(name).unwrap();
let fun2 =
find_df_window_func(name.to_uppercase().as_str()).unwrap();
assert_eq!(fun, fun2);
- assert_eq!(fun.to_string(), name.to_uppercase());
+ if fun.to_string() == "first_value" || fun.to_string() ==
"last_value" {
+ assert_eq!(fun.to_string(), name);
+ } else {
+ assert_eq!(fun.to_string(), name.to_uppercase());
+ }
}
Ok(())
}
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 8c9d3c7885..694911592b 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -169,6 +169,8 @@ pub fn max(expr: Expr) -> Expr {
}
/// Create an expression to represent the sum() aggregate function
+///
+/// TODO: Remove this function and use `sum` from
`datafusion_functions_aggregate::expr_fn` instead
pub fn sum(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Sum,
diff --git a/datafusion/expr/src/expr_schema.rs
b/datafusion/expr/src/expr_schema.rs
index 01c9edff30..57470db2e0 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -21,8 +21,10 @@ use crate::expr::{
InSubquery, Placeholder, ScalarFunction, Sort, TryCast, Unnest,
WindowFunction,
};
use crate::type_coercion::binary::get_result_type;
-use crate::type_coercion::functions::data_types_with_scalar_udf;
-use crate::{utils, LogicalPlan, Projection, Subquery};
+use crate::type_coercion::functions::{
+ data_types_with_aggregate_udf, data_types_with_scalar_udf,
+};
+use crate::{utils, LogicalPlan, Projection, Subquery,
WindowFunctionDefinition};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{
@@ -158,7 +160,25 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
- fun.return_type(&data_types)
+ match fun {
+ WindowFunctionDefinition::AggregateUDF(udf) => {
+ let new_types =
data_types_with_aggregate_udf(&data_types, udf).map_err(|err| {
+ plan_datafusion_err!(
+ "{} and {}",
+ err,
+ utils::generate_signature_error_msg(
+ fun.name(),
+ fun.signature().clone(),
+ &data_types
+ )
+ )
+ })?;
+ Ok(fun.return_type(&new_types)?)
+ }
+ _ => {
+ fun.return_type(&data_types)
+ }
+ }
}
Expr::AggregateFunction(AggregateFunction { func_def, args, .. })
=> {
let data_types = args
@@ -170,7 +190,18 @@ impl ExprSchemable for Expr {
fun.return_type(&data_types)
}
AggregateFunctionDefinition::UDF(fun) => {
- Ok(fun.return_type(&data_types)?)
+ let new_types =
data_types_with_aggregate_udf(&data_types, fun).map_err(|err| {
+ plan_datafusion_err!(
+ "{} and {}",
+ err,
+ utils::generate_signature_error_msg(
+ fun.name(),
+ fun.signature().clone(),
+ &data_types
+ )
+ )
+ })?;
+ Ok(fun.return_type(&new_types)?)
}
}
}
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 7f49b03bb2..c06f177510 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -70,6 +70,9 @@ pub struct AccumulatorArgs<'a> {
/// If no `ORDER BY` is specified, `sort_exprs`` will be empty.
pub sort_exprs: &'a [Expr],
+ /// The name of the aggregate expression
+ pub name: &'a str,
+
/// Whether the aggregate function is distinct.
///
/// ```sql
@@ -82,9 +85,6 @@ pub struct AccumulatorArgs<'a> {
/// The number of arguments the aggregate function takes.
pub args_num: usize,
-
- /// The name of the expression
- pub name: &'a str,
}
/// [`StateFieldsArgs`] contains information about the fields that an
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 74d6b4149d..bbd1d6f654 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -64,7 +64,7 @@ pub use built_in_window_function::BuiltInWindowFunction;
pub use columnar_value::ColumnarValue;
pub use expr::{
Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField,
GroupingSet,
- Like, TryCast, WindowFunctionDefinition,
+ Like, Sort as SortExpr, TryCast, WindowFunctionDefinition,
};
pub use expr_fn::*;
pub use expr_schema::ExprSchemable;
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 0274038a36..d778203207 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -83,6 +83,12 @@ impl std::hash::Hash for AggregateUDF {
}
}
+impl std::fmt::Display for AggregateUDF {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(f, "{}", self.name())
+ }
+}
+
impl AggregateUDF {
/// Create a new AggregateUDF
///
@@ -190,8 +196,22 @@ impl AggregateUDF {
}
/// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
- pub fn create_groups_accumulator(&self) -> Result<Box<dyn
GroupsAccumulator>> {
- self.inner.create_groups_accumulator()
+ pub fn create_groups_accumulator(
+ &self,
+ args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
+ self.inner.create_groups_accumulator(args)
+ }
+
+ pub fn create_sliding_accumulator(
+ &self,
+ args: AccumulatorArgs,
+ ) -> Result<Box<dyn Accumulator>> {
+ self.inner.create_sliding_accumulator(args)
+ }
+
+ pub fn coerce_types(&self, arg_types: &[DataType]) ->
Result<Vec<DataType>> {
+ self.inner.coerce_types(arg_types)
}
/// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details.
@@ -213,16 +233,8 @@ impl AggregateUDF {
/// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will
/// generate same result with this `AggregateUDF` when iterated in reverse
/// order, and `None` if there is no such `AggregateUDF`).
- pub fn reverse_udf(&self) -> Option<AggregateUDF> {
- match self.inner.reverse_expr() {
- ReversedUDAF::NotSupported => None,
- ReversedUDAF::Identical => Some(self.clone()),
- ReversedUDAF::Reversed(reverse) => Some(Self { inner: reverse }),
- }
- }
-
- pub fn coerce_types(&self, _arg_types: &[DataType]) ->
Result<Vec<DataType>> {
- not_impl_err!("coerce_types not implemented for {:?} yet", self.name())
+ pub fn reverse_udf(&self) -> ReversedUDAF {
+ self.inner.reverse_expr()
}
/// Do the function rewrite
@@ -327,7 +339,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
///
/// # Arguments:
/// 1. `name`: the name of the expression (e.g. AVG, SUM, etc)
- /// 2. `value_type`: Aggregate's aggregate's output (returned by
[`Self::return_type`])
+ /// 2. `value_type`: Aggregate function output returned by
[`Self::return_type`] if defined, otherwise
+ /// it is equivalent to the data type of the first arguments
/// 3. `ordering_fields`: the fields used to order the input arguments, if
any.
/// Empty if no ordering expression is provided.
///
@@ -377,7 +390,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
///
/// For maximum performance, a [`GroupsAccumulator`] should be
/// implemented in addition to [`Accumulator`].
- fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+ fn create_groups_accumulator(
+ &self,
+ _args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?}
yet")
}
@@ -389,6 +405,19 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
&[]
}
+ /// Sliding accumulator is an alternative accumulator that can be used for
+ /// window functions. It has retract method to revert the previous update.
+ ///
+ /// See [retract_batch] for more details.
+ ///
+ /// [retract_batch]: crate::accumulator::Accumulator::retract_batch
+ fn create_sliding_accumulator(
+ &self,
+ args: AccumulatorArgs,
+ ) -> Result<Box<dyn Accumulator>> {
+ self.accumulator(args)
+ }
+
/// Sets the indicator whether ordering requirements of the
AggregateUDFImpl is
/// satisfied by its input. If this is not the case, UDFs with order
/// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce
@@ -451,6 +480,29 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::NotSupported
}
+
+ /// Coerce arguments of a function call to types that the function can
evaluate.
+ ///
+ /// This function is only called if [`AggregateUDFImpl::signature`]
returns [`crate::TypeSignature::UserDefined`]. Most
+ /// UDAFs should return one of the other variants of `TypeSignature` which
handle common
+ /// cases
+ ///
+ /// See the [type coercion module](crate::type_coercion)
+ /// documentation for more details on type coercion
+ ///
+ /// For example, if your function requires a floating point arguments, but
the user calls
+ /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types
could return `[DataType::Float64]`
+ /// to ensure the argument was cast to `1::double`
+ ///
+ /// # Parameters
+ /// * `arg_types`: The argument types of the arguments this function with
+ ///
+ /// # Return value
+ /// A Vec the same length as `arg_types`. DataFusion will `CAST` the
function call
+ /// arguments to these specific types.
+ fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ not_impl_err!("Function {} does not implement coerce_types",
self.name())
+ }
}
pub enum ReversedUDAF {
@@ -459,7 +511,7 @@ pub enum ReversedUDAF {
/// The expression does not support reverse calculation, like ArrayAgg
NotSupported,
/// The expression is different from the original expression
- Reversed(Arc<dyn AggregateUDFImpl>),
+ Reversed(Arc<AggregateUDF>),
}
/// AggregateUDF that adds an alias to the underlying function. It is better to
diff --git a/datafusion/functions-aggregate/Cargo.toml
b/datafusion/functions-aggregate/Cargo.toml
index 696bbaece9..26630a0352 100644
--- a/datafusion/functions-aggregate/Cargo.toml
+++ b/datafusion/functions-aggregate/Cargo.toml
@@ -38,6 +38,7 @@ path = "src/lib.rs"
# See more keys and their definitions at
https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
+ahash = { workspace = true }
arrow = { workspace = true }
arrow-schema = { workspace = true }
datafusion-common = { workspace = true }
diff --git a/datafusion/functions-aggregate/src/first_last.rs
b/datafusion/functions-aggregate/src/first_last.rs
index f1cb92045f..fe4501c149 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -75,7 +75,8 @@ impl FirstValue {
vec![
// TODO: we can introduce more strict signature that only
numeric of array types are allowed
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
- TypeSignature::Uniform(1, NUMERICS.to_vec()),
+ TypeSignature::Numeric(1),
+ TypeSignature::Uniform(1, vec![DataType::Utf8]),
],
Volatility::Immutable,
),
@@ -159,7 +160,7 @@ impl AggregateUDFImpl for FirstValue {
}
fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
- datafusion_expr::ReversedUDAF::Reversed(last_value_udaf().inner())
+ datafusion_expr::ReversedUDAF::Reversed(last_value_udaf())
}
}
@@ -462,7 +463,7 @@ impl AggregateUDFImpl for LastValue {
}
fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
- datafusion_expr::ReversedUDAF::Reversed(first_value_udaf().inner())
+ datafusion_expr::ReversedUDAF::Reversed(first_value_udaf())
}
}
diff --git a/datafusion/functions-aggregate/src/lib.rs
b/datafusion/functions-aggregate/src/lib.rs
index e82897e926..cb8ef65420 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -58,6 +58,7 @@ pub mod macros;
pub mod covariance;
pub mod first_last;
pub mod median;
+pub mod sum;
use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;
@@ -72,6 +73,7 @@ pub mod expr_fn {
pub use super::first_last::first_value;
pub use super::first_last::last_value;
pub use super::median::median;
+ pub use super::sum::sum;
}
/// Returns all default aggregate functions
@@ -80,6 +82,7 @@ pub fn all_default_aggregate_functions() ->
Vec<Arc<AggregateUDF>> {
first_last::first_value_udaf(),
first_last::last_value_udaf(),
covariance::covar_samp_udaf(),
+ sum::sum_udaf(),
covariance::covar_pop_udaf(),
median::median_udaf(),
]
@@ -110,6 +113,11 @@ mod tests {
fn test_no_duplicate_name() -> Result<()> {
let mut names = HashSet::new();
for func in all_default_aggregate_functions() {
+ // TODO: remove this
+ // sum is in intermediate migration state, skip this
+ if func.name().to_lowercase() == "sum" {
+ continue;
+ }
assert!(
names.insert(func.name().to_string().to_lowercase()),
"duplicate function name: {}",
diff --git a/datafusion/functions-aggregate/src/sum.rs
b/datafusion/functions-aggregate/src/sum.rs
new file mode 100644
index 0000000000..b3127726cb
--- /dev/null
+++ b/datafusion/functions-aggregate/src/sum.rs
@@ -0,0 +1,436 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators
+
+use ahash::RandomState;
+use datafusion_expr::utils::AggregateOrderSensitivity;
+use std::any::Any;
+use std::collections::HashSet;
+
+use arrow::array::Array;
+use arrow::array::ArrowNativeTypeOp;
+use arrow::array::{ArrowNumericType, AsArray};
+use arrow::datatypes::ArrowNativeType;
+use arrow::datatypes::ArrowPrimitiveType;
+use arrow::datatypes::{
+ DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type,
UInt64Type,
+ DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
+};
+use arrow::{array::ArrayRef, datatypes::Field};
+use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
+use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::function::StateFieldsArgs;
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{
+ Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature,
Volatility,
+};
+use
datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
+use datafusion_physical_expr_common::aggregate::utils::Hashable;
+
+make_udaf_expr_and_func!(
+ Sum,
+ sum,
+ expression,
+ "Returns the sum of a group of values.",
+ sum_udaf
+);
+
+/// Sum only supports a subset of numeric types, instead relying on type
coercion
+///
+/// This macro is similar to
[downcast_primitive](arrow::array::downcast_primitive)
+///
+/// `args` is [AccumulatorArgs]
+/// `helper` is a macro accepting (ArrowPrimitiveType, DataType)
+macro_rules! downcast_sum {
+ ($args:ident, $helper:ident) => {
+ match $args.data_type {
+ DataType::UInt64 => $helper!(UInt64Type, $args.data_type),
+ DataType::Int64 => $helper!(Int64Type, $args.data_type),
+ DataType::Float64 => $helper!(Float64Type, $args.data_type),
+ DataType::Decimal128(_, _) => $helper!(Decimal128Type,
$args.data_type),
+ DataType::Decimal256(_, _) => $helper!(Decimal256Type,
$args.data_type),
+ _ => {
+ not_impl_err!("Sum not supported for {}: {}", $args.name,
$args.data_type)
+ }
+ }
+ };
+}
+
+#[derive(Debug)]
+pub struct Sum {
+ signature: Signature,
+ aliases: Vec<String>,
+}
+
+impl Sum {
+ pub fn new() -> Self {
+ Self {
+ signature: Signature::user_defined(Volatility::Immutable),
+ aliases: vec!["sum".to_string()],
+ }
+ }
+}
+
+impl Default for Sum {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl AggregateUDFImpl for Sum {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "SUM"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ if arg_types.len() != 1 {
+ return exec_err!("SUM expects exactly one argument");
+ }
+
+ // Refer to
https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
+ // smallint, int, bigint, real, double precision, decimal, or interval.
+
+ fn coerced_type(data_type: &DataType) -> Result<DataType> {
+ match data_type {
+ DataType::Dictionary(_, v) => coerced_type(v),
+ // in the spark, the result type is
DECIMAL(min(38,precision+10), s)
+ // ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+ DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => {
+ Ok(data_type.clone())
+ }
+ dt if dt.is_signed_integer() => Ok(DataType::Int64),
+ dt if dt.is_unsigned_integer() => Ok(DataType::UInt64),
+ dt if dt.is_floating() => Ok(DataType::Float64),
+ _ => exec_err!("Sum not supported for {}", data_type),
+ }
+ }
+
+ Ok(vec![coerced_type(&arg_types[0])?])
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ match &arg_types[0] {
+ DataType::Int64 => Ok(DataType::Int64),
+ DataType::UInt64 => Ok(DataType::UInt64),
+ DataType::Float64 => Ok(DataType::Float64),
+ DataType::Decimal128(precision, scale) => {
+ // in the spark, the result type is
DECIMAL(min(38,precision+10), s)
+ // ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+ let new_precision = DECIMAL128_MAX_PRECISION.min(*precision +
10);
+ Ok(DataType::Decimal128(new_precision, *scale))
+ }
+ DataType::Decimal256(precision, scale) => {
+ // in the spark, the result type is
DECIMAL(min(38,precision+10), s)
+ // ref:
https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+ let new_precision = DECIMAL256_MAX_PRECISION.min(*precision +
10);
+ Ok(DataType::Decimal256(new_precision, *scale))
+ }
+ other => {
+ exec_err!("[return_type] SUM not supported for {}", other)
+ }
+ }
+ }
+
+ fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
+ if args.is_distinct {
+ macro_rules! helper {
+ ($t:ty, $dt:expr) => {
+ Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?))
+ };
+ }
+ downcast_sum!(args, helper)
+ } else {
+ macro_rules! helper {
+ ($t:ty, $dt:expr) => {
+ Ok(Box::new(SumAccumulator::<$t>::new($dt.clone())))
+ };
+ }
+ downcast_sum!(args, helper)
+ }
+ }
+
+ fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+ if args.is_distinct {
+ Ok(vec![Field::new_list(
+ format_state_name(args.name, "sum distinct"),
+ Field::new("item", args.return_type.clone(), true),
+ false,
+ )])
+ } else {
+ Ok(vec![Field::new(
+ format_state_name(args.name, "sum"),
+ args.return_type.clone(),
+ true,
+ )])
+ }
+ }
+
+ fn aliases(&self) -> &[String] {
+ &self.aliases
+ }
+
+ fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+ !args.is_distinct
+ }
+
+ fn create_groups_accumulator(
+ &self,
+ args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
+ macro_rules! helper {
+ ($t:ty, $dt:expr) => {
+ Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
+ &$dt,
+ |x, y| *x = x.add_wrapping(y),
+ )))
+ };
+ }
+ downcast_sum!(args, helper)
+ }
+
+ fn create_sliding_accumulator(
+ &self,
+ args: AccumulatorArgs,
+ ) -> Result<Box<dyn Accumulator>> {
+ macro_rules! helper {
+ ($t:ty, $dt:expr) => {
+ Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone())))
+ };
+ }
+ downcast_sum!(args, helper)
+ }
+
+ fn reverse_expr(&self) -> ReversedUDAF {
+ ReversedUDAF::Identical
+ }
+
+ fn order_sensitivity(&self) -> AggregateOrderSensitivity {
+ AggregateOrderSensitivity::Insensitive
+ }
+}
+
+/// This accumulator computes SUM incrementally
+struct SumAccumulator<T: ArrowNumericType> {
+ sum: Option<T::Native>,
+ data_type: DataType,
+}
+
+impl<T: ArrowNumericType> std::fmt::Debug for SumAccumulator<T> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "SumAccumulator({})", self.data_type)
+ }
+}
+
+impl<T: ArrowNumericType> SumAccumulator<T> {
+ fn new(data_type: DataType) -> Self {
+ Self {
+ sum: None,
+ data_type,
+ }
+ }
+}
+
+impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![self.evaluate()?])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let values = values[0].as_primitive::<T>();
+ if let Some(x) = arrow::compute::sum(values) {
+ let v = self.sum.get_or_insert(T::Native::usize_as(0));
+ *v = v.add_wrapping(x);
+ }
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ self.update_batch(states)
+ }
+
+ fn evaluate(&mut self) -> Result<ScalarValue> {
+ ScalarValue::new_primitive::<T>(self.sum, &self.data_type)
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+}
+
+/// This accumulator incrementally computes sums over a sliding window
+///
+/// This is separate from [`SumAccumulator`] as it requires additional state
+struct SlidingSumAccumulator<T: ArrowNumericType> {
+ sum: T::Native,
+ count: u64,
+ data_type: DataType,
+}
+
+impl<T: ArrowNumericType> std::fmt::Debug for SlidingSumAccumulator<T> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "SlidingSumAccumulator({})", self.data_type)
+ }
+}
+
+impl<T: ArrowNumericType> SlidingSumAccumulator<T> {
+ fn new(data_type: DataType) -> Self {
+ Self {
+ sum: T::Native::usize_as(0),
+ count: 0,
+ data_type,
+ }
+ }
+}
+
+impl<T: ArrowNumericType> Accumulator for SlidingSumAccumulator<T> {
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![self.evaluate()?, self.count.into()])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let values = values[0].as_primitive::<T>();
+ self.count += (values.len() - values.null_count()) as u64;
+ if let Some(x) = arrow::compute::sum(values) {
+ self.sum = self.sum.add_wrapping(x)
+ }
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ let values = states[0].as_primitive::<T>();
+ if let Some(x) = arrow::compute::sum(values) {
+ self.sum = self.sum.add_wrapping(x)
+ }
+ if let Some(x) =
arrow::compute::sum(states[1].as_primitive::<UInt64Type>()) {
+ self.count += x;
+ }
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> Result<ScalarValue> {
+ let v = (self.count != 0).then_some(self.sum);
+ ScalarValue::new_primitive::<T>(v, &self.data_type)
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let values = values[0].as_primitive::<T>();
+ if let Some(x) = arrow::compute::sum(values) {
+ self.sum = self.sum.sub_wrapping(x)
+ }
+ self.count -= (values.len() - values.null_count()) as u64;
+ Ok(())
+ }
+
+ fn supports_retract_batch(&self) -> bool {
+ true
+ }
+}
+
+struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
+ values: HashSet<Hashable<T::Native>, RandomState>,
+ data_type: DataType,
+}
+
+impl<T: ArrowPrimitiveType> std::fmt::Debug for DistinctSumAccumulator<T> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "DistinctSumAccumulator({})", self.data_type)
+ }
+}
+
+impl<T: ArrowPrimitiveType> DistinctSumAccumulator<T> {
+ pub fn try_new(data_type: &DataType) -> Result<Self> {
+ Ok(Self {
+ values: HashSet::default(),
+ data_type: data_type.clone(),
+ })
+ }
+}
+
+impl<T: ArrowPrimitiveType> Accumulator for DistinctSumAccumulator<T> {
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
+ // 1. Stores aggregate state in `ScalarValue::List`
+ // 2. Constructs `ScalarValue::List` state from distinct numeric
stored in hash set
+ let state_out = {
+ let distinct_values = self
+ .values
+ .iter()
+ .map(|value| {
+ ScalarValue::new_primitive::<T>(Some(value.0),
&self.data_type)
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ vec![ScalarValue::List(ScalarValue::new_list(
+ &distinct_values,
+ &self.data_type,
+ ))]
+ };
+ Ok(state_out)
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ if values.is_empty() {
+ return Ok(());
+ }
+
+ let array = values[0].as_primitive::<T>();
+ match array.nulls().filter(|x| x.null_count() > 0) {
+ Some(n) => {
+ for idx in n.valid_indices() {
+ self.values.insert(Hashable(array.value(idx)));
+ }
+ }
+ None => array.values().iter().for_each(|x| {
+ self.values.insert(Hashable(*x));
+ }),
+ }
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ for x in states[0].as_list::<i32>().iter().flatten() {
+ self.update_batch(&[x])?
+ }
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> Result<ScalarValue> {
+ let mut acc = T::Native::usize_as(0);
+ for distinct_value in self.values.iter() {
+ acc = acc.add_wrapping(distinct_value.0)
+ }
+ let v = (!self.values.is_empty()).then_some(acc);
+ ScalarValue::new_primitive::<T>(v, &self.data_type)
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ + self.values.capacity() * std::mem::size_of::<T::Native>()
+ }
+}
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 081a54ac44..31dc9028b9 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -430,6 +430,13 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
&fun.signature(),
)?
}
+ expr::WindowFunctionDefinition::AggregateUDF(udf) => {
+ coerce_arguments_for_signature_with_aggregate_udf(
+ args,
+ self.schema,
+ udf,
+ )?
+ }
_ => args,
};
@@ -985,13 +992,10 @@ mod test {
None,
None,
));
- let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf],
empty)?);
- let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan,
"")
- .err()
- .unwrap();
- assert_eq!(
- "type_coercion\ncaused by\nError during planning: Coercion from
[Utf8] to the signature Uniform(1, [Float64]) failed.",
- err.strip_backtrace()
+
+ let err = Projection::try_new(vec![udaf], empty).err().unwrap();
+ assert!(
+ err.strip_backtrace().starts_with("Error during planning: Error
during planning: Coercion from [Utf8] to the signature Uniform(1, [Float64])
failed")
);
Ok(())
}
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index c87654292a..024cb74403 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -3804,7 +3804,10 @@ mod tests {
unimplemented!("not needed for testing")
}
- fn create_groups_accumulator(&self) -> Result<Box<dyn
GroupsAccumulator>> {
+ fn create_groups_accumulator(
+ &self,
+ _args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
unimplemented!("not needed for testing")
}
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs
b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 27449c8dd5..06d0dee270 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -259,7 +259,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
}
Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::UDF(udf),
- args,
+ mut args,
distinct,
..
}) => {
@@ -267,7 +267,6 @@ impl OptimizerRule for SingleDistinctToGroupBy {
if args.len() != 1 {
return internal_err!("DISTINCT aggregate
should have exactly one argument");
}
- let mut args = args;
let arg = args.swap_remove(0);
if
group_fields_set.insert(arg.display_name()?) {
@@ -298,7 +297,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
.alias(&alias_str),
);
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
- udf.clone(),
+ udf,
vec![col(&alias_str)],
false,
None,
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs
b/datafusion/physical-expr-common/src/aggregate/mod.rs
index 78c7d40b87..2273418c60 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -19,6 +19,14 @@ pub mod groups_accumulator;
pub mod stats;
pub mod utils;
+use arrow::datatypes::{DataType, Field, Schema};
+use datafusion_common::{not_impl_err, Result};
+use datafusion_expr::function::StateFieldsArgs;
+use datafusion_expr::type_coercion::aggregates::check_arg_count;
+use datafusion_expr::ReversedUDAF;
+use datafusion_expr::{
+ function::AccumulatorArgs, Accumulator, AggregateUDF, Expr,
GroupsAccumulator,
+};
use std::fmt::Debug;
use std::{any::Any, sync::Arc};
@@ -27,14 +35,8 @@ use crate::physical_expr::PhysicalExpr;
use crate::sort_expr::{LexOrdering, PhysicalSortExpr};
use crate::utils::reverse_order_bys;
-use arrow::datatypes::{DataType, Field, Schema};
-use datafusion_common::{exec_err, not_impl_err, Result};
-use datafusion_expr::function::StateFieldsArgs;
-use datafusion_expr::type_coercion::aggregates::check_arg_count;
+use datafusion_common::exec_err;
use datafusion_expr::utils::AggregateOrderSensitivity;
-use datafusion_expr::{
- function::AccumulatorArgs, Accumulator, AggregateUDF, Expr,
GroupsAccumulator,
-};
/// Creates a physical expression of the UDAF, that includes all necessary
type coercion.
/// This function errors when `args`' can't be coerced to a valid argument
type of the UDAF.
@@ -50,6 +52,7 @@ pub fn create_aggregate_expr(
is_distinct: bool,
) -> Result<Arc<dyn AggregateExpr>> {
debug_assert_eq!(sort_exprs.len(), ordering_req.len());
+
let input_exprs_types = input_phy_exprs
.iter()
.map(|arg| arg.data_type(schema))
@@ -222,7 +225,7 @@ pub struct AggregatePhysicalExpressions {
}
/// Physical aggregate expression of a UDAF.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct AggregateFunctionExpr {
fun: AggregateUDF,
args: Vec<Arc<dyn PhysicalExpr>>,
@@ -234,7 +237,9 @@ pub struct AggregateFunctionExpr {
sort_exprs: Vec<Expr>,
// The physical order by expressions
ordering_req: LexOrdering,
+ // Whether to ignore null values
ignore_nulls: bool,
+ // fields used for order sensitive aggregation functions
ordering_fields: Vec<Field>,
is_distinct: bool,
input_type: DataType,
@@ -294,7 +299,18 @@ impl AggregateExpr for AggregateFunctionExpr {
}
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- let accumulator = self.create_accumulator()?;
+ let args = AccumulatorArgs {
+ data_type: &self.data_type,
+ schema: &self.schema,
+ ignore_nulls: self.ignore_nulls,
+ sort_exprs: &self.sort_exprs,
+ is_distinct: self.is_distinct,
+ input_type: &self.input_type,
+ args_num: self.args.len(),
+ name: &self.name,
+ };
+
+ let accumulator = self.fun.create_sliding_accumulator(args)?;
// Accumulators that have window frame startings different
// than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to
@@ -367,11 +383,29 @@ impl AggregateExpr for AggregateFunctionExpr {
}
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
- self.fun.create_groups_accumulator()
+ let args = AccumulatorArgs {
+ data_type: &self.data_type,
+ schema: &self.schema,
+ ignore_nulls: self.ignore_nulls,
+ sort_exprs: &self.sort_exprs,
+ is_distinct: self.is_distinct,
+ input_type: &self.input_type,
+ args_num: self.args.len(),
+ name: &self.name,
+ };
+ self.fun.create_groups_accumulator(args)
}
fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
- (!self.ordering_req.is_empty()).then_some(&self.ordering_req)
+ if self.ordering_req.is_empty() {
+ return None;
+ }
+
+ if !self.order_sensitivity().is_insensitive() {
+ return Some(&self.ordering_req);
+ }
+
+ None
}
fn order_sensitivity(&self) -> AggregateOrderSensitivity {
@@ -409,37 +443,41 @@ impl AggregateExpr for AggregateFunctionExpr {
}
fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
- if let Some(reverse_udf) = self.fun.reverse_udf() {
- let reverse_ordering_req = reverse_order_bys(&self.ordering_req);
- let reverse_sort_exprs = self
- .sort_exprs
- .iter()
- .map(|e| {
- if let Expr::Sort(s) = e {
- Expr::Sort(s.reverse())
- } else {
- // Expects to receive `Expr::Sort`.
- unreachable!()
- }
- })
- .collect::<Vec<_>>();
- let mut name = self.name().to_string();
- replace_order_by_clause(&mut name);
- replace_fn_name_clause(&mut name, self.fun.name(),
reverse_udf.name());
- let reverse_aggr = create_aggregate_expr(
- &reverse_udf,
- &self.args,
- &reverse_sort_exprs,
- &reverse_ordering_req,
- &self.schema,
- name,
- self.ignore_nulls,
- self.is_distinct,
- )
- .unwrap();
- return Some(reverse_aggr);
+ match self.fun.reverse_udf() {
+ ReversedUDAF::NotSupported => None,
+ ReversedUDAF::Identical => Some(Arc::new(self.clone())),
+ ReversedUDAF::Reversed(reverse_udf) => {
+ let reverse_ordering_req =
reverse_order_bys(&self.ordering_req);
+ let reverse_sort_exprs = self
+ .sort_exprs
+ .iter()
+ .map(|e| {
+ if let Expr::Sort(s) = e {
+ Expr::Sort(s.reverse())
+ } else {
+ // Expects to receive `Expr::Sort`.
+ unreachable!()
+ }
+ })
+ .collect::<Vec<_>>();
+ let mut name = self.name().to_string();
+ replace_order_by_clause(&mut name);
+ replace_fn_name_clause(&mut name, self.fun.name(),
reverse_udf.name());
+ let reverse_aggr = create_aggregate_expr(
+ &reverse_udf,
+ &self.args,
+ &reverse_sort_exprs,
+ &reverse_ordering_req,
+ &self.schema,
+ name,
+ self.ignore_nulls,
+ self.is_distinct,
+ )
+ .unwrap();
+
+ Some(reverse_aggr)
+ }
}
- None
}
}
diff --git a/datafusion/physical-expr-common/src/aggregate/utils.rs
b/datafusion/physical-expr-common/src/aggregate/utils.rs
index c59c29a139..bcd0d05be0 100644
--- a/datafusion/physical-expr-common/src/aggregate/utils.rs
+++ b/datafusion/physical-expr-common/src/aggregate/utils.rs
@@ -17,9 +17,10 @@
use std::{any::Any, sync::Arc};
+use arrow::array::{ArrayRef, AsArray};
use arrow::datatypes::ArrowNativeType;
use arrow::{
- array::{ArrayRef, ArrowNativeTypeOp, AsArray},
+ array::ArrowNativeTypeOp,
compute::SortOptions,
datatypes::{
DataType, Decimal128Type, DecimalType, Field, TimeUnit,
TimestampMicrosecondType,
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs
b/datafusion/physical-expr/src/aggregate/build_in.rs
index e100089954..813a394d69 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -28,14 +28,15 @@
use std::sync::Arc;
+use arrow::datatypes::Schema;
+
+use datafusion_common::{exec_err, internal_err, not_impl_err, Result};
+use datafusion_expr::AggregateFunction;
+
use crate::aggregate::regr::RegrType;
use crate::expressions::{self, Literal};
use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
-use arrow::datatypes::Schema;
-use datafusion_common::{exec_err, not_impl_err, Result};
-use datafusion_expr::AggregateFunction;
-
/// Create a physical aggregation expression.
/// This function errors when `input_phy_exprs`' can't be coerced to a valid
argument type of the aggregation function.
pub fn create_aggregate_expr(
@@ -103,16 +104,9 @@ pub fn create_aggregate_expr(
name,
data_type,
)),
- (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new(
- input_phy_exprs[0].clone(),
- name,
- input_phy_types[0].clone(),
- )),
- (AggregateFunction::Sum, true) =>
Arc::new(expressions::DistinctSum::new(
- vec![input_phy_exprs[0].clone()],
- name,
- data_type,
- )),
+ (AggregateFunction::Sum, _) => {
+ return internal_err!("Builtin Sum will be removed");
+ }
(AggregateFunction::ApproxDistinct, _) => Arc::new(
expressions::ApproxDistinct::new(input_phy_exprs[0].clone(), name,
data_type),
),
@@ -378,7 +372,7 @@ mod tests {
use crate::expressions::{
try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont,
ArrayAgg, Avg,
BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg,
DistinctCount,
- Max, Min, Stddev, Sum, Variance,
+ Max, Min, Stddev, Variance,
};
use datafusion_common::{plan_err, DataFusionError, ScalarValue};
@@ -689,7 +683,7 @@ mod tests {
#[test]
fn test_sum_avg_expr() -> Result<()> {
- let funcs = vec![AggregateFunction::Sum, AggregateFunction::Avg];
+ let funcs = vec![AggregateFunction::Avg];
let data_types = vec![
DataType::UInt32,
DataType::UInt64,
@@ -712,37 +706,13 @@ mod tests {
&input_schema,
"c1",
)?;
- match fun {
- AggregateFunction::Sum => {
- assert!(result_agg_phy_exprs.as_any().is::<Sum>());
- assert_eq!("c1", result_agg_phy_exprs.name());
- let expect_type = match data_type {
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64 => DataType::UInt64,
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64 => DataType::Int64,
- DataType::Float32 | DataType::Float64 =>
DataType::Float64,
- _ => data_type.clone(),
- };
-
- assert_eq!(
- Field::new("c1", expect_type.clone(), true),
- result_agg_phy_exprs.field().unwrap()
- );
- }
- AggregateFunction::Avg => {
- assert!(result_agg_phy_exprs.as_any().is::<Avg>());
- assert_eq!("c1", result_agg_phy_exprs.name());
- assert_eq!(
- Field::new("c1", DataType::Float64, true),
- result_agg_phy_exprs.field().unwrap()
- );
- }
- _ => {}
+ if fun == AggregateFunction::Avg {
+ assert!(result_agg_phy_exprs.as_any().is::<Avg>());
+ assert_eq!("c1", result_agg_phy_exprs.name());
+ assert_eq!(
+ Field::new("c1", DataType::Float64, true),
+ result_agg_phy_exprs.field().unwrap()
+ );
};
}
}
@@ -976,44 +946,6 @@ mod tests {
Ok(())
}
- #[test]
- fn test_sum_return_type() -> Result<()> {
- let observed = AggregateFunction::Sum.return_type(&[DataType::Int32])?;
- assert_eq!(DataType::Int64, observed);
-
- let observed = AggregateFunction::Sum.return_type(&[DataType::UInt8])?;
- assert_eq!(DataType::UInt64, observed);
-
- let observed =
AggregateFunction::Sum.return_type(&[DataType::Float32])?;
- assert_eq!(DataType::Float64, observed);
-
- let observed =
AggregateFunction::Sum.return_type(&[DataType::Float64])?;
- assert_eq!(DataType::Float64, observed);
-
- let observed =
- AggregateFunction::Sum.return_type(&[DataType::Decimal128(10,
5)])?;
- assert_eq!(DataType::Decimal128(20, 5), observed);
-
- let observed =
- AggregateFunction::Sum.return_type(&[DataType::Decimal128(35,
5)])?;
- assert_eq!(DataType::Decimal128(38, 5), observed);
-
- Ok(())
- }
-
- #[test]
- fn test_sum_no_utf8() {
- let observed = AggregateFunction::Sum.return_type(&[DataType::Utf8]);
- assert!(observed.is_err());
- }
-
- #[test]
- fn test_sum_upcasts() -> Result<()> {
- let observed =
AggregateFunction::Sum.return_type(&[DataType::UInt32])?;
- assert_eq!(DataType::UInt64, observed);
- Ok(())
- }
-
#[test]
fn test_count_return_type() -> Result<()> {
let observed =
AggregateFunction::Count.return_type(&[DataType::Utf8])?;
diff --git a/datafusion/physical-plan/src/windows/mod.rs
b/datafusion/physical-plan/src/windows/mod.rs
index 42c630741c..9b392d941e 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -31,10 +31,11 @@ use crate::{
use arrow::datatypes::Schema;
use arrow_schema::{DataType, Field, SchemaRef};
-use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
+use datafusion_common::{exec_err, Column, DataFusionError, Result,
ScalarValue};
+use datafusion_expr::Expr;
use datafusion_expr::{
- BuiltInWindowFunction, PartitionEvaluator, WindowFrame,
WindowFunctionDefinition,
- WindowUDF,
+ BuiltInWindowFunction, PartitionEvaluator, SortExpr, WindowFrame,
+ WindowFunctionDefinition, WindowUDF,
};
use datafusion_physical_expr::equivalence::collapse_lex_req;
use datafusion_physical_expr::{
@@ -70,12 +71,17 @@ pub fn schema_add_window_field(
.iter()
.map(|f| f.as_ref().clone())
.collect_vec();
- window_fields.extend_from_slice(&[Field::new(
- fn_name,
- window_expr_return_type,
- false,
- )]);
- Ok(Arc::new(Schema::new(window_fields)))
+ // Skip extending schema for UDAF
+ if let WindowFunctionDefinition::AggregateUDF(_) = window_fn {
+ Ok(Arc::new(Schema::new(window_fields)))
+ } else {
+ window_fields.extend_from_slice(&[Field::new(
+ fn_name,
+ window_expr_return_type,
+ false,
+ )]);
+ Ok(Arc::new(Schema::new(window_fields)))
+ }
}
/// Create a physical expression for window function
@@ -118,14 +124,28 @@ pub fn create_window_expr(
}
WindowFunctionDefinition::AggregateUDF(fun) => {
// TODO: Ordering not supported for Window UDFs yet
- let sort_exprs = &[];
- let ordering_req = &[];
+ // Convert `Vec<PhysicalSortExpr>` into `Vec<Expr::Sort>`
+ let sort_exprs = order_by
+ .iter()
+ .map(|PhysicalSortExpr { expr, options }| {
+ let field_name = expr.to_string();
+ let field_name =
field_name.split('@').next().unwrap_or(&field_name);
+ Expr::Sort(SortExpr {
+ expr: Box::new(Expr::Column(Column::new(
+ None::<String>,
+ field_name,
+ ))),
+ asc: !options.descending,
+ nulls_first: options.nulls_first,
+ })
+ })
+ .collect::<Vec<_>>();
let aggregate = udaf::create_aggregate_expr(
fun.as_ref(),
args,
- sort_exprs,
- ordering_req,
+ &sort_exprs,
+ order_by,
input_schema,
name,
ignore_nulls,
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index c065948d3b..0408ea91b9 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -871,7 +871,7 @@ message PhysicalWindowExprNode {
oneof window_function {
AggregateFunction aggr_function = 1;
BuiltInWindowFunction built_in_function = 2;
- // udaf = 3
+ string user_defined_aggr_function = 3;
}
repeated PhysicalExprNode args = 4;
repeated PhysicalExprNode partition_by = 5;
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 7e7a14a5d1..e07fbba27d 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -15965,6 +15965,9 @@ impl serde::Serialize for PhysicalWindowExprNode {
.map_err(|_|
serde::ser::Error::custom(format!("Invalid variant {}", *v)))?;
struct_ser.serialize_field("builtInFunction", &v)?;
}
+
physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(v) => {
+ struct_ser.serialize_field("userDefinedAggrFunction", v)?;
+ }
}
}
struct_ser.end()
@@ -15989,6 +15992,8 @@ impl<'de> serde::Deserialize<'de> for
PhysicalWindowExprNode {
"aggrFunction",
"built_in_function",
"builtInFunction",
+ "user_defined_aggr_function",
+ "userDefinedAggrFunction",
];
#[allow(clippy::enum_variant_names)]
@@ -16000,6 +16005,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalWindowExprNode {
Name,
AggrFunction,
BuiltInFunction,
+ UserDefinedAggrFunction,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -16028,6 +16034,7 @@ impl<'de> serde::Deserialize<'de> for
PhysicalWindowExprNode {
"name" => Ok(GeneratedField::Name),
"aggrFunction" | "aggr_function" =>
Ok(GeneratedField::AggrFunction),
"builtInFunction" | "built_in_function" =>
Ok(GeneratedField::BuiltInFunction),
+ "userDefinedAggrFunction" |
"user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -16097,6 +16104,12 @@ impl<'de> serde::Deserialize<'de> for
PhysicalWindowExprNode {
}
window_function__ =
map_.next_value::<::std::option::Option<BuiltInWindowFunction>>()?.map(|x|
physical_window_expr_node::WindowFunction::BuiltInFunction(x as i32));
}
+ GeneratedField::UserDefinedAggrFunction => {
+ if window_function__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("userDefinedAggrFunction"));
+ }
+ window_function__ =
map_.next_value::<::std::option::Option<_>>()?.map(physical_window_expr_node::WindowFunction::UserDefinedAggrFunction);
+ }
}
}
Ok(PhysicalWindowExprNode {
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index f9138da3ab..c75cb36158 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1345,7 +1345,7 @@ pub struct PhysicalWindowExprNode {
pub window_frame: ::core::option::Option<WindowFrame>,
#[prost(string, tag = "8")]
pub name: ::prost::alloc::string::String,
- #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "1,
2")]
+ #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "1, 2,
3")]
pub window_function: ::core::option::Option<
physical_window_expr_node::WindowFunction,
>,
@@ -1357,9 +1357,10 @@ pub mod physical_window_expr_node {
pub enum WindowFunction {
#[prost(enumeration = "super::AggregateFunction", tag = "1")]
AggrFunction(i32),
- /// udaf = 3
#[prost(enumeration = "super::BuiltInWindowFunction", tag = "2")]
BuiltInFunction(i32),
+ #[prost(string, tag = "3")]
+ UserDefinedAggrFunction(::prost::alloc::string::String),
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs
b/datafusion/proto/src/physical_plan/from_proto.rs
index cf935e6b83..0a91df568a 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -145,8 +145,37 @@ pub fn parse_physical_window_expr(
)
})?;
- let fun: WindowFunctionDefinition =
convert_required!(proto.window_function)?;
+ let fun = if let Some(window_func) = proto.window_function.as_ref() {
+ match window_func {
+
protobuf::physical_window_expr_node::WindowFunction::AggrFunction(n) => {
+ let f = protobuf::AggregateFunction::try_from(*n).map_err(|_| {
+ proto_error(format!(
+ "Received an unknown window aggregate function: {n}"
+ ))
+ })?;
+
+ WindowFunctionDefinition::AggregateFunction(f.into())
+ }
+
protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => {
+ let f =
protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| {
+ proto_error(format!(
+ "Received an unknown window builtin function: {n}"
+ ))
+ })?;
+
+ WindowFunctionDefinition::BuiltInWindowFunction(f.into())
+ }
+
protobuf::physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(udaf_name)
=> {
+ let agg_udf = registry.udaf(udaf_name)?;
+ WindowFunctionDefinition::AggregateUDF(agg_udf)
+ }
+ }
+ } else {
+ return Err(proto_error("Missing required field in protobuf"));
+ };
+
let name = proto.name.clone();
+    // TODO: Remove extended_schema once all aggregate functions are migrated to UDAFs
let extended_schema =
schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?;
create_window_expr(
@@ -383,37 +412,6 @@ fn parse_required_physical_expr(
})
}
-impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction>
- for WindowFunctionDefinition
-{
- type Error = DataFusionError;
-
- fn try_from(
- expr: &protobuf::physical_window_expr_node::WindowFunction,
- ) -> Result<Self, Self::Error> {
- match expr {
-
protobuf::physical_window_expr_node::WindowFunction::AggrFunction(n) => {
- let f = protobuf::AggregateFunction::try_from(*n).map_err(|_| {
- proto_error(format!(
- "Received an unknown window aggregate function: {n}"
- ))
- })?;
-
- Ok(WindowFunctionDefinition::AggregateFunction(f.into()))
- }
-
protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => {
- let f =
protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| {
- proto_error(format!(
- "Received an unknown window builtin function: {n}"
- ))
- })?;
-
- Ok(WindowFunctionDefinition::BuiltInWindowFunction(f.into()))
- }
- }
- }
-}
-
pub fn parse_protobuf_hash_partitioning(
partitioning: Option<&protobuf::PhysicalHashRepartition>,
registry: &dyn FunctionRegistry,
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs
b/datafusion/proto/src/physical_plan/to_proto.rs
index 3135d09593..0714636141 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -186,21 +186,29 @@ pub fn serialize_physical_window_expr(
} else if let Some(sliding_aggr_window_expr) =
expr.downcast_ref::<SlidingAggregateWindowExpr>()
{
- let AggrFn { inner, distinct } =
-
aggr_expr_to_aggr_fn(sliding_aggr_window_expr.get_aggregate_expr().as_ref())?;
+ let aggr_expr = sliding_aggr_window_expr.get_aggregate_expr();
+ if let Some(a) =
aggr_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
+ physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(
+ a.fun().name().to_string(),
+ )
+ } else {
+ let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(
+ sliding_aggr_window_expr.get_aggregate_expr().as_ref(),
+ )?;
+
+ if distinct {
+ // TODO
+ return not_impl_err!(
+ "Distinct aggregate functions not supported in window
expressions"
+ );
+ }
- if distinct {
- // TODO
- return not_impl_err!(
- "Distinct aggregate functions not supported in window
expressions"
- );
- }
+ if window_frame.start_bound.is_unbounded() {
+ return Err(DataFusionError::Internal(format!("Invalid
SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame =
{window_frame:?}")));
+ }
- if window_frame.start_bound.is_unbounded() {
- return Err(DataFusionError::Internal(format!("Invalid
SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame =
{window_frame:?}")));
+ physical_window_expr_node::WindowFunction::AggrFunction(inner as
i32)
}
-
- physical_window_expr_node::WindowFunction::AggrFunction(inner as i32)
} else {
return not_impl_err!("WindowExpr not supported: {window_expr:?}");
};
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index b756d4688d..14d7227480 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -31,7 +31,8 @@ use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::execution::FunctionRegistry;
-use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp};
+use datafusion::functions_aggregate::expr_fn::{covar_pop, covar_samp,
first_value};
+use datafusion::functions_aggregate::median::median;
use datafusion::prelude::*;
use datafusion::test_util::{TestTableFactory, TestTableProvider};
use datafusion_common::config::{FormatOptions, TableOptions};
@@ -648,6 +649,7 @@ async fn roundtrip_expr_api() -> Result<()> {
first_value(vec![lit(1)], false, None, None, None),
covar_samp(lit(1.5), lit(2.2)),
covar_pop(lit(1.5), lit(2.2)),
+ sum(lit(1)),
median(lit(2)),
];
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index df1995f465..9cf686dbd3 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -21,6 +21,7 @@ use std::sync::Arc;
use std::vec;
use arrow::csv::WriterBuilder;
+use datafusion::functions_aggregate::sum::sum_udaf;
use prost::Message;
use datafusion::arrow::array::ArrayRef;
@@ -47,7 +48,7 @@ use datafusion::physical_plan::analyze::AnalyzeExec;
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::physical_plan::expressions::{
binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column,
DistinctCount,
- NotExpr, NthValue, PhysicalSortExpr, StringAgg, Sum,
+ NotExpr, NthValue, PhysicalSortExpr, StringAgg,
};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::insert::DataSinkExec;
@@ -296,12 +297,20 @@ fn roundtrip_window() -> Result<()> {
WindowFrameBound::Preceding(ScalarValue::Int64(None)),
);
+ let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?];
+ let sum_expr = udaf::create_aggregate_expr(
+ &sum_udaf(),
+ &args,
+ &[],
+ &[],
+ &schema,
+ "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING",
+ false,
+ false,
+ )?;
+
let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new(
- Arc::new(Sum::new(
- cast(col("a", &schema)?, &schema, DataType::Float64)?,
- "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING",
- DataType::Float64,
- )),
+ sum_expr,
&[],
&[],
Arc::new(window_frame),
diff --git a/datafusion/sql/src/expr/function.rs
b/datafusion/sql/src/expr/function.rs
index 1f8492b9ba..81a9b4b772 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -297,22 +297,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&self,
name: &str,
) -> Result<WindowFunctionDefinition> {
- expr::find_df_window_func(name)
- // next check user defined aggregates
- .or_else(|| {
- self.context_provider
- .get_aggregate_meta(name)
- .map(WindowFunctionDefinition::AggregateUDF)
- })
- // next check user defined window functions
- .or_else(|| {
- self.context_provider
- .get_window_meta(name)
- .map(WindowFunctionDefinition::WindowUDF)
- })
- .ok_or_else(|| {
- plan_datafusion_err!("There is no window function named
{name}")
- })
+        // Check user-defined aggregate functions (UDAFs) first
+ let udaf = self.context_provider.get_aggregate_meta(name);
+        // Skip first_value and last_value: for window expressions we expect the
built-in window versions, not their UDAF counterparts
+ if udaf.as_ref().is_some_and(|udaf| {
+ udaf.name() != "first_value" && udaf.name() != "last_value"
+ }) {
+ Ok(WindowFunctionDefinition::AggregateUDF(udaf.unwrap()))
+ } else {
+ expr::find_df_window_func(name)
+ .or_else(|| {
+ self.context_provider
+ .get_window_meta(name)
+ .map(WindowFunctionDefinition::WindowUDF)
+ })
+ .ok_or_else(|| {
+ plan_datafusion_err!("There is no window function named
{name}")
+ })
+ }
}
fn sql_fn_arg_to_logical_expr(
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index df6a376448..98e64b025b 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3559,10 +3559,10 @@ NULL NULL NULL NULL NULL NULL NULL NULL Row 2 Y
# aggregate_timestamps_sum
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'SUM\(Timestamp\(Nanosecond, None\)\)'\. You
might need to add explicit type casts\.
+query error
SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t;
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'SUM\(Timestamp\(Nanosecond, None\)\)'\. You
might need to add explicit type casts\.
+query error
SELECT tag, sum(nanos), sum(micros), sum(millis), sum(secs) FROM t GROUP BY
tag ORDER BY tag;
# aggregate_timestamps_count
@@ -3670,10 +3670,10 @@ NULL NULL Row 2 Y
# aggregate_timestamps_sum
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'SUM\(Date32\)'\. You might need to add
explicit type casts\.
+query error
SELECT sum(date32), sum(date64) FROM t;
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'SUM\(Date32\)'\. You might need to add
explicit type casts\.
+query error
SELECT tag, sum(date32), sum(date64) FROM t GROUP BY tag ORDER BY tag;
# aggregate_timestamps_count
@@ -3767,10 +3767,10 @@ select * from t;
21:06:28.247821084 21:06:28.247821 21:06:28.247 21:06:28 Row 3 B
# aggregate_times_sum
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'SUM\(Time64\(Nanosecond\)\)'\. You might
need to add explicit type casts\.
+query error
SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'SUM\(Time64\(Nanosecond\)\)'\. You might
need to add explicit type casts\.
+query error
SELECT tag, sum(nanos), sum(micros), sum(millis), sum(secs) FROM t GROUP BY
tag ORDER BY tag
# aggregate_times_count
diff --git a/datafusion/sqllogictest/test_files/order.slt
b/datafusion/sqllogictest/test_files/order.slt
index d7f10537d0..2678e8cbd1 100644
--- a/datafusion/sqllogictest/test_files/order.slt
+++ b/datafusion/sqllogictest/test_files/order.slt
@@ -1131,4 +1131,4 @@ physical_plan
01)SortPreservingMergeExec: [c@0 ASC NULLS LAST]
02)--ProjectionExec: expr=[CAST(inc_col@0 > desc_col@1 AS Int32) as c]
03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
-04)------CsvExec: file_groups={1 group:
[[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]},
projection=[inc_col, desc_col], output_orderings=[[inc_col@0 ASC NULLS LAST],
[desc_col@1 DESC]], has_header=true
\ No newline at end of file
+04)------CsvExec: file_groups={1 group:
[[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]},
projection=[inc_col, desc_col], output_orderings=[[inc_col@0 ASC NULLS LAST],
[desc_col@1 DESC]], has_header=true
diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt
b/datafusion/sqllogictest/test_files/sort_merge_join.slt
index babb7dc8fd..ce738c7a6f 100644
--- a/datafusion/sqllogictest/test_files/sort_merge_join.slt
+++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt
@@ -344,6 +344,7 @@ t1 as (
select 11 a, 13 b)
select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a
and t2.b != t1.b)
) order by 1, 2;
+----
query II
select * from (
diff --git a/datafusion/sqllogictest/test_files/unnest.slt
b/datafusion/sqllogictest/test_files/unnest.slt
index bdd7e6631c..8866cd009c 100644
--- a/datafusion/sqllogictest/test_files/unnest.slt
+++ b/datafusion/sqllogictest/test_files/unnest.slt
@@ -65,7 +65,7 @@ select * from unnest(struct(1,2,3));
----
1 2 3
-## Multiple unnest expression in from clause
+## Multiple unnest expression in from clause
query IIII
select * from unnest(struct(1,2,3)),unnest([4,5,6]);
----
@@ -446,7 +446,7 @@ query error DataFusion error: type_coercion\ncaused
by\nThis feature is not impl
select sum(unnest(generate_series(1,10)));
## TODO: support unnest as a child expr
-query error DataFusion error: Internal error: unnest on struct can ony be
applied at the root level of select expression
+query error DataFusion error: Internal error: unnest on struct can ony be
applied at the root level of select expression
select arrow_typeof(unnest(column5)) from unnest_table;
diff --git a/datafusion/sqllogictest/test_files/window.slt
b/datafusion/sqllogictest/test_files/window.slt
index be1517aa75..2d5dd439d7 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -1344,16 +1344,16 @@ EXPLAIN SELECT
LIMIT 5
----
logical_plan
-01)Projection: aggregate_test_100.c9, FIRST_VALUE(aggregate_test_100.c9) ORDER
BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING AS fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING AS fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND
CURRENT ROW AS lag1, LAG(aggregate_test_100.c9,Int64( [...]
+01)Projection: aggregate_test_100.c9, first_value(aggregate_test_100.c9) ORDER
BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING AS fv1, first_value(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING AS fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND
CURRENT ROW AS lag1, LAG(aggregate_test_100.c9,Int64( [...]
02)--Limit: skip=0, fetch=5
-03)----WindowAggr: windowExpr=[[FIRST_VALUE(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND
CURRENT ROW, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND
CURRENT ROW]]
-04)------WindowAggr: windowExpr=[[FIRST_VALUE(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1
FOLLOWING, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1
FOLLOWING]]
+03)----WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND
CURRENT ROW, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND
CURRENT ROW]]
+04)------WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1
FOLLOWING, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1
FOLLOWING]]
05)--------TableScan: aggregate_test_100 projection=[c9]
physical_plan
-01)ProjectionExec: expr=[c9@0 as c9, FIRST_VALUE(aggregate_test_100.c9) ORDER
BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING@4 as fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING@1 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND
CURRENT ROW@5 as lag1, LAG(aggregate_test_100.c9,I [...]
+01)ProjectionExec: expr=[c9@0 as c9, first_value(aggregate_test_100.c9) ORDER
BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING@4 as fv1, first_value(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING@1 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND
CURRENT ROW@5 as lag1, LAG(aggregate_test_100.c9,I [...]
02)--GlobalLimitExec: skip=0, fetch=5
-03)----BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING: Ok(Field { name: "FIRST_VALUE(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered:
false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound:
Preceding(UInt64(5)), end_bound: Following(UInt64(1)), [...]
-04)------BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9) ORDER
BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING: Ok(Field { name: "FIRST_VALUE(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered:
false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound:
Preceding(UInt64(1)), end_bound: Following(UInt6 [...]
+03)----BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered:
false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound:
Preceding(UInt64(5)), end_bound: Following(UInt64(1)), [...]
+04)------BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER
BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY
[aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5
FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered:
false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound:
Preceding(UInt64(1)), end_bound: Following(UInt6 [...]
05)--------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false]
06)----------CsvExec: file_groups={1 group:
[[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9],
has_header=true
@@ -2634,16 +2634,16 @@ EXPLAIN SELECT
logical_plan
01)Limit: skip=0, fetch=5
02)--Sort: annotated_data_finite.ts DESC NULLS FIRST, fetch=5
-03)----Projection: annotated_data_finite.ts,
FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1,
FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2,
LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1,
LAST_VALUE(ann [...]
-04)------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING
AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1
FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1
FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDE [...]
-05)--------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING
AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1
FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1
FOLLOWING, LAST_VALUE(annotated_data_finite.inc_c [...]
+03)----Projection: annotated_data_finite.ts,
first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1,
first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2,
last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1,
last_value(ann [...]
+04)------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING
AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1
FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1
FOLLOWING, last_value(annotated_data_finite.inc_col) ORDE [...]
+05)--------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING
AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1
FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1
FOLLOWING, last_value(annotated_data_finite.inc_c [...]
06)----------TableScan: annotated_data_finite projection=[ts, inc_col]
physical_plan
01)GlobalLimitExec: skip=0, fetch=5
02)--SortExec: TopK(fetch=5), expr=[ts@0 DESC], preserve_partitioning=[false]
-03)----ProjectionExec: expr=[ts@0 as ts,
FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1,
FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2,
LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, LAST_VALU
[...]
-04)------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING
AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING
AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range,
start_bound: Preceding(Int32(10)), end_ [...]
-05)--------BoundedWindowAggExec:
wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1
FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER
BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1
FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered:
false, metadata: {} }), frame: WindowFrame { units: Range, start_bound:
Preceding(Int32(1)), [...]
+03)----ProjectionExec: expr=[ts@0 as ts,
first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1,
first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2,
last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, last_valu
[...]
+04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING
AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING
AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range,
start_bound: Preceding(Int32(10)), end_ [...]
+05)--------BoundedWindowAggExec:
wdw=[first_value(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1
FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER
BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1
FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered:
false, metadata: {} }), frame: WindowFrame { units: Range, start_bound:
Preceding(Int32(1)), [...]
06)----------CsvExec: file_groups={1 group:
[[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts,
inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true
query IIIIIIIIIIIIIIIIIIIIIIIII
@@ -2761,17 +2761,17 @@ logical_plan
01)Projection: first_value1, first_value2, last_value1, last_value2, nth_value1
02)--Limit: skip=0, fetch=5
03)----Sort: annotated_data_finite.inc_col ASC NULLS LAST, fetch=5
-04)------Projection: FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND
1 FOLLOWING AS first_value1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER
BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND
UNBOUNDED FOLLOWING AS first_value2, LAST_VALUE(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED
PRECEDING AND 1 FOLLOWING AS last [...]
-05)--------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED
PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND
1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(2)) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND
1 FOLLOWING]]
-06)----------WindowAggr:
windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND
UNBOUNDED FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND
UNBOUNDED FOLLOWING]]
+04)------Projection: first_value(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND
1 FOLLOWING AS first_value1, first_value(annotated_data_finite.inc_col) ORDER
BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND
UNBOUNDED FOLLOWING AS first_value2, last_value(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED
PRECEDING AND 1 FOLLOWING AS last [...]
+05)--------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED
PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND
1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(2)) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND
1 FOLLOWING]]
+06)----------WindowAggr:
windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND
UNBOUNDED FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND
UNBOUNDED FOLLOWING]]
07)------------TableScan: annotated_data_finite projection=[ts, inc_col]
physical_plan
01)ProjectionExec: expr=[first_value1@0 as first_value1, first_value2@1 as
first_value2, last_value1@2 as last_value1, last_value2@3 as last_value2,
nth_value1@4 as nth_value1]
02)--GlobalLimitExec: skip=0, fetch=5
03)----SortExec: TopK(fetch=5), expr=[inc_col@5 ASC NULLS LAST],
preserve_partitioning=[false]
-04)------ProjectionExec: expr=[FIRST_VALUE(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED
PRECEDING AND 1 FOLLOWING@4 as first_value1,
FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as
first_value2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND
1 FOL [...]
-05)--------BoundedWindowAggExec:
wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND
1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER
BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING
AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UIn [...]
-06)----------BoundedWindowAggExec:
wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND
UNBOUNDED FOLLOWING: Ok(Field { name:
"FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type:
Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }),
frame: WindowFrame { units: Rows, start_bound: Precedi [...]
+04)------ProjectionExec: expr=[first_value(annotated_data_finite.inc_col)
ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED
PRECEDING AND 1 FOLLOWING@4 as first_value1,
first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as
first_value2, last_value(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND
1 FOL [...]
+05)--------BoundedWindowAggExec:
wdw=[first_value(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND
1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER
BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING
AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows,
start_bound: Preceding(UIn [...]
+06)----------BoundedWindowAggExec:
wdw=[first_value(annotated_data_finite.inc_col) ORDER BY
[annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND
UNBOUNDED FOLLOWING: Ok(Field { name:
"first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts
DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type:
Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }),
frame: WindowFrame { units: Rows, start_bound: Precedi [...]
07)------------CsvExec: file_groups={1 group:
[[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts,
inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true
query IIIII
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]