This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 8f718dd3ce Move `Count` to `functions-aggregate`, update MSRV to rust 1.75 (#10484)
8f718dd3ce is described below
commit 8f718dd3ce291c9f5688144ca6c9d7d854dc4b0b
Author: Jay Zhan <[email protected]>
AuthorDate: Thu Jun 13 07:54:39 2024 +0800
Move `Count` to `functions-aggregate`, update MSRV to rust 1.75 (#10484)
* mv accumulate indices
Signed-off-by: jayzhan211 <[email protected]>
* complete udaf
Signed-off-by: jayzhan211 <[email protected]>
* register
Signed-off-by: jayzhan211 <[email protected]>
* fix expr
Signed-off-by: jayzhan211 <[email protected]>
* filter distinct count
Signed-off-by: jayzhan211 <[email protected]>
* todo: need to move count distinct too
Signed-off-by: jayzhan211 <[email protected]>
* move code around
Signed-off-by: jayzhan211 <[email protected]>
* move distinct to aggr-crate
Signed-off-by: jayzhan211 <[email protected]>
* replace
Signed-off-by: jayzhan211 <[email protected]>
* backup
Signed-off-by: jayzhan211 <[email protected]>
* fix function name and physical expr
Signed-off-by: jayzhan211 <[email protected]>
* fix physical optimizer
Signed-off-by: jayzhan211 <[email protected]>
* fix all slt
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* fix with args
Signed-off-by: jayzhan211 <[email protected]>
* add label
Signed-off-by: jayzhan211 <[email protected]>
* revert builtin related code back
Signed-off-by: jayzhan211 <[email protected]>
* fix test
Signed-off-by: jayzhan211 <[email protected]>
* fix substrait
Signed-off-by: jayzhan211 <[email protected]>
* fix doc
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* fix
Signed-off-by: jayzhan211 <[email protected]>
* fix udaf macro for distinct but not apply
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* fix count distinct and use workspace
Signed-off-by: jayzhan211 <[email protected]>
* add reverse
Signed-off-by: jayzhan211 <[email protected]>
* remove old code
Signed-off-by: jayzhan211 <[email protected]>
* backup
Signed-off-by: jayzhan211 <[email protected]>
* use macro
Signed-off-by: jayzhan211 <[email protected]>
* expr builder
Signed-off-by: jayzhan211 <[email protected]>
* introduce expr builder
Signed-off-by: jayzhan211 <[email protected]>
* add example
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* clean agg sta
Signed-off-by: jayzhan211 <[email protected]>
* combine agg
Signed-off-by: jayzhan211 <[email protected]>
* limit distinct and fmt
Signed-off-by: jayzhan211 <[email protected]>
* cleanup name
Signed-off-by: jayzhan211 <[email protected]>
* fix ci
Signed-off-by: jayzhan211 <[email protected]>
* fix window
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* fix ci
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* fix merged
Signed-off-by: jayzhan211 <[email protected]>
* fix
Signed-off-by: jayzhan211 <[email protected]>
* fix rebase
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* use std
Signed-off-by: jayzhan211 <[email protected]>
* update msrv
Signed-off-by: jayzhan211 <[email protected]>
* upd msrv
Signed-off-by: jayzhan211 <[email protected]>
* revert test
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* downgrade to 1.75
Signed-off-by: jayzhan211 <[email protected]>
* 1.76
Signed-off-by: jayzhan211 <[email protected]>
* ahash
Signed-off-by: jayzhan211 <[email protected]>
* revert to 1.75
Signed-off-by: jayzhan211 <[email protected]>
* rm count
Signed-off-by: jayzhan211 <[email protected]>
* fix merge
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* clippy
Signed-off-by: jayzhan211 <[email protected]>
* rm sum in test_no_duplicate_name
Signed-off-by: jayzhan211 <[email protected]>
* fix
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
Cargo.toml | 4 +-
datafusion-cli/Cargo.lock | 2 +
datafusion-cli/Cargo.toml | 2 +-
datafusion/core/Cargo.toml | 2 +-
datafusion/core/src/dataframe/mod.rs | 13 +-
.../src/physical_optimizer/aggregate_statistics.rs | 79 +--
.../combine_partial_final_agg.rs | 47 +-
.../limited_distinct_aggregation.rs | 16 +-
.../core/src/physical_optimizer/test_utils.rs | 5 +-
datafusion/core/src/physical_planner.rs | 1 -
.../provider_filter_pushdown.rs | 1 +
datafusion/core/tests/dataframe/mod.rs | 11 +-
datafusion/core/tests/fuzz_cases/window_fuzz.rs | 5 +-
datafusion/expr/src/expr.rs | 2 +-
datafusion/expr/src/expr_fn.rs | 2 +
datafusion/functions-aggregate/src/count.rs | 562 ++++++++++++++++
datafusion/functions-aggregate/src/lib.rs | 8 +-
datafusion/optimizer/src/decorrelate.rs | 10 +-
.../optimizer/src/single_distinct_to_groupby.rs | 3 +-
datafusion/physical-expr-common/Cargo.toml | 2 +
.../src/aggregate/count_distinct/bytes.rs | 6 +-
.../{lib.rs => aggregate/count_distinct/mod.rs} | 12 +-
.../src/aggregate/count_distinct/native.rs | 23 +-
.../physical-expr-common/src/aggregate/mod.rs | 1 +
.../src/binary_map.rs | 21 +-
datafusion/physical-expr-common/src/lib.rs | 1 +
datafusion/physical-expr/src/aggregate/build_in.rs | 92 +--
datafusion/physical-expr/src/aggregate/count.rs | 348 ----------
.../src/aggregate/count_distinct/mod.rs | 718 ---------------------
.../src/aggregate/groups_accumulator/mod.rs | 2 +-
datafusion/physical-expr/src/aggregate/mod.rs | 2 -
datafusion/physical-expr/src/expressions/mod.rs | 2 -
datafusion/physical-expr/src/lib.rs | 4 +-
.../src/aggregates/group_values/bytes.rs | 2 +-
datafusion/physical-plan/src/aggregates/mod.rs | 19 +-
.../src/windows/bounded_window_agg_exec.rs | 7 +-
datafusion/physical-plan/src/windows/mod.rs | 4 +-
datafusion/proto-common/Cargo.toml | 2 +-
datafusion/proto-common/gen/Cargo.toml | 2 +-
datafusion/proto/Cargo.toml | 2 +-
datafusion/proto/gen/Cargo.toml | 2 +-
datafusion/proto/proto/datafusion.proto | 1 +
datafusion/proto/src/generated/pbjson.rs | 17 +
datafusion/proto/src/generated/prost.rs | 2 +
datafusion/proto/src/logical_plan/from_proto.rs | 2 +-
datafusion/proto/src/logical_plan/to_proto.rs | 1 +
datafusion/proto/src/physical_plan/to_proto.rs | 15 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 2 +
.../proto/tests/cases/roundtrip_physical_plan.rs | 33 +-
datafusion/sqllogictest/test_files/errors.slt | 4 +-
datafusion/substrait/Cargo.toml | 2 +-
datafusion/substrait/src/logical_plan/consumer.rs | 12 +-
52 files changed, 805 insertions(+), 1335 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index 65ef191d74..aa1ba1f214 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -52,7 +52,7 @@ homepage = "https://datafusion.apache.org"
license = "Apache-2.0"
readme = "README.md"
repository = "https://github.com/apache/datafusion"
-rust-version = "1.73"
+rust-version = "1.75"
version = "39.0.0"
[workspace.dependencies]
@@ -107,7 +107,7 @@ doc-comment = "0.3"
env_logger = "0.11"
futures = "0.3"
half = { version = "2.2.1", default-features = false }
-hashbrown = { version = "0.14", features = ["raw"] }
+hashbrown = { version = "0.14.5", features = ["raw"] }
indexmap = "2.0.0"
itertools = "0.12"
log = "^0.4"
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 932f44d984..c5b34df4f1 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1376,9 +1376,11 @@ dependencies = [
name = "datafusion-physical-expr-common"
version = "39.0.0"
dependencies = [
+ "ahash",
"arrow",
"datafusion-common",
"datafusion-expr",
+ "hashbrown 0.14.5",
"rand",
]
diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml
index 5e393246b9..8f4b3cd81f 100644
--- a/datafusion-cli/Cargo.toml
+++ b/datafusion-cli/Cargo.toml
@@ -26,7 +26,7 @@ license = "Apache-2.0"
homepage = "https://datafusion.apache.org"
repository = "https://github.com/apache/datafusion"
# Specify MSRV here as `cargo msrv` doesn't support workspace version
-rust-version = "1.73"
+rust-version = "1.75"
readme = "README.md"
[dependencies]
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml
index 7533e2cff1..45617d88dc 100644
--- a/datafusion/core/Cargo.toml
+++ b/datafusion/core/Cargo.toml
@@ -30,7 +30,7 @@ authors = { workspace = true }
# Specify MSRV here as `cargo msrv` doesn't support workspace version and
fails with
# "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in
'arrow-datafusion/Cargo.toml'"
# https://github.com/foresterre/cargo-msrv/issues/590
-rust-version = "1.73"
+rust-version = "1.75"
[lints]
workspace = true
diff --git a/datafusion/core/src/dataframe/mod.rs
b/datafusion/core/src/dataframe/mod.rs
index 06a85d3036..950cb7ddb2 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -50,12 +50,11 @@ use datafusion_common::{
};
use datafusion_expr::lit;
use datafusion_expr::{
- avg, count, max, min, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown,
+ avg, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null};
-use datafusion_functions_aggregate::expr_fn::sum;
-use datafusion_functions_aggregate::expr_fn::{median, stddev};
+use datafusion_functions_aggregate::expr_fn::{count, median, stddev, sum};
use async_trait::async_trait;
@@ -854,10 +853,7 @@ impl DataFrame {
/// ```
pub async fn count(self) -> Result<usize> {
let rows = self
- .aggregate(
- vec![],
-
vec![datafusion_expr::count(Expr::Literal(COUNT_STAR_EXPANSION))],
- )?
+ .aggregate(vec![],
vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])?
.collect()
.await?;
let len = *rows
@@ -1594,9 +1590,10 @@ mod tests {
use datafusion_common::{Constraint, Constraints};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::{
- array_agg, cast, count_distinct, create_udf, expr, lit,
BuiltInWindowFunction,
+ array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction,
ScalarFunctionImplementation, Volatility, WindowFrame,
WindowFunctionDefinition,
};
+ use datafusion_functions_aggregate::expr_fn::count_distinct;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};
diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index 05f05d95b8..eeacc48b85 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -170,38 +170,6 @@ fn take_optimizable_column_and_table_count(
}
}
}
- // TODO: Remove this after revmoing Builtin Count
- else if let (&Precision::Exact(num_rows), Some(casted_expr)) = (
- &stats.num_rows,
- agg_expr.as_any().downcast_ref::<expressions::Count>(),
- ) {
- // TODO implementing Eq on PhysicalExpr would help a lot here
- if casted_expr.expressions().len() == 1 {
- // TODO optimize with exprs other than Column
- if let Some(col_expr) = casted_expr.expressions()[0]
- .as_any()
- .downcast_ref::<expressions::Column>()
- {
- let current_val = &col_stats[col_expr.index()].null_count;
- if let &Precision::Exact(val) = current_val {
- return Some((
- ScalarValue::Int64(Some((num_rows - val) as i64)),
- casted_expr.name().to_string(),
- ));
- }
- } else if let Some(lit_expr) = casted_expr.expressions()[0]
- .as_any()
- .downcast_ref::<expressions::Literal>()
- {
- if lit_expr.value() == &COUNT_STAR_EXPANSION {
- return Some((
- ScalarValue::Int64(Some(num_rows as i64)),
- casted_expr.name().to_owned(),
- ));
- }
- }
- }
- }
None
}
@@ -307,13 +275,12 @@ fn take_optimizable_max(
#[cfg(test)]
pub(crate) mod tests {
-
use super::*;
+
use crate::logical_expr::Operator;
use crate::physical_plan::aggregates::PhysicalGroupBy;
use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use crate::physical_plan::common;
- use crate::physical_plan::expressions::Count;
use crate::physical_plan::filter::FilterExec;
use crate::physical_plan::memory::MemoryExec;
use crate::prelude::SessionContext;
@@ -322,8 +289,10 @@ pub(crate) mod tests {
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::cast::as_int64_array;
+ use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::expressions::cast;
use datafusion_physical_expr::PhysicalExpr;
+ use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
use datafusion_physical_plan::aggregates::AggregateMode;
/// Mock data using a MemoryExec which has an exact count statistic
@@ -414,13 +383,19 @@ pub(crate) mod tests {
Self::ColumnA(schema.clone())
}
- /// Return appropriate expr depending if COUNT is for col or table (*)
- pub(crate) fn count_expr(&self) -> Arc<dyn AggregateExpr> {
- Arc::new(Count::new(
- self.column(),
+ // Return appropriate expr depending if COUNT is for col or table (*)
+ pub(crate) fn count_expr(&self, schema: &Schema) -> Arc<dyn
AggregateExpr> {
+ create_aggregate_expr(
+ &count_udaf(),
+ &[self.column()],
+ &[],
+ &[],
+ schema,
self.column_name(),
- DataType::Int64,
- ))
+ false,
+ false,
+ )
+ .unwrap()
}
/// what argument would this aggregate need in the plan?
@@ -458,7 +433,7 @@ pub(crate) mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::default(),
- vec![agg.count_expr()],
+ vec![agg.count_expr(&schema)],
vec![None],
source,
Arc::clone(&schema),
@@ -467,7 +442,7 @@ pub(crate) mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::default(),
- vec![agg.count_expr()],
+ vec![agg.count_expr(&schema)],
vec![None],
Arc::new(partial_agg),
Arc::clone(&schema),
@@ -488,7 +463,7 @@ pub(crate) mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::default(),
- vec![agg.count_expr()],
+ vec![agg.count_expr(&schema)],
vec![None],
source,
Arc::clone(&schema),
@@ -497,7 +472,7 @@ pub(crate) mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::default(),
- vec![agg.count_expr()],
+ vec![agg.count_expr(&schema)],
vec![None],
Arc::new(partial_agg),
Arc::clone(&schema),
@@ -517,7 +492,7 @@ pub(crate) mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::default(),
- vec![agg.count_expr()],
+ vec![agg.count_expr(&schema)],
vec![None],
source,
Arc::clone(&schema),
@@ -529,7 +504,7 @@ pub(crate) mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::default(),
- vec![agg.count_expr()],
+ vec![agg.count_expr(&schema)],
vec![None],
Arc::new(coalesce),
Arc::clone(&schema),
@@ -549,7 +524,7 @@ pub(crate) mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::default(),
- vec![agg.count_expr()],
+ vec![agg.count_expr(&schema)],
vec![None],
source,
Arc::clone(&schema),
@@ -561,7 +536,7 @@ pub(crate) mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::default(),
- vec![agg.count_expr()],
+ vec![agg.count_expr(&schema)],
vec![None],
Arc::new(coalesce),
Arc::clone(&schema),
@@ -592,7 +567,7 @@ pub(crate) mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::default(),
- vec![agg.count_expr()],
+ vec![agg.count_expr(&schema)],
vec![None],
filter,
Arc::clone(&schema),
@@ -601,7 +576,7 @@ pub(crate) mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::default(),
- vec![agg.count_expr()],
+ vec![agg.count_expr(&schema)],
vec![None],
Arc::new(partial_agg),
Arc::clone(&schema),
@@ -637,7 +612,7 @@ pub(crate) mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::default(),
- vec![agg.count_expr()],
+ vec![agg.count_expr(&schema)],
vec![None],
filter,
Arc::clone(&schema),
@@ -646,7 +621,7 @@ pub(crate) mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::default(),
- vec![agg.count_expr()],
+ vec![agg.count_expr(&schema)],
vec![None],
Arc::new(partial_agg),
Arc::clone(&schema),
diff --git
a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
index 3ad61e52c8..38b92959e8 100644
--- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
+++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
@@ -206,8 +206,9 @@ mod tests {
use crate::physical_plan::{displayable, Partitioning};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+ use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
- use datafusion_physical_expr::expressions::{col, Count};
+ use datafusion_physical_expr::expressions::col;
use datafusion_physical_plan::udaf::create_aggregate_expr;
/// Runs the CombinePartialFinalAggregate optimizer and asserts the plan
against the expected
@@ -303,15 +304,31 @@ mod tests {
)
}
+ // Return appropriate expr depending if COUNT is for col or table (*)
+ fn count_expr(
+ expr: Arc<dyn PhysicalExpr>,
+ name: &str,
+ schema: &Schema,
+ ) -> Arc<dyn AggregateExpr> {
+ create_aggregate_expr(
+ &count_udaf(),
+ &[expr],
+ &[],
+ &[],
+ schema,
+ name,
+ false,
+ false,
+ )
+ .unwrap()
+ }
+
#[test]
fn aggregations_not_combined() -> Result<()> {
let schema = schema();
- let aggr_expr = vec![Arc::new(Count::new(
- lit(1i8),
- "COUNT(1)".to_string(),
- DataType::Int64,
- )) as _];
+ let aggr_expr = vec![count_expr(lit(1i8), "COUNT(1)", &schema)];
+
let plan = final_aggregate_exec(
repartition_exec(partial_aggregate_exec(
parquet_exec(&schema),
@@ -330,16 +347,8 @@ mod tests {
];
assert_optimized!(expected, plan);
- let aggr_expr1 = vec![Arc::new(Count::new(
- lit(1i8),
- "COUNT(1)".to_string(),
- DataType::Int64,
- )) as _];
- let aggr_expr2 = vec![Arc::new(Count::new(
- lit(1i8),
- "COUNT(2)".to_string(),
- DataType::Int64,
- )) as _];
+ let aggr_expr1 = vec![count_expr(lit(1i8), "COUNT(1)", &schema)];
+ let aggr_expr2 = vec![count_expr(lit(1i8), "COUNT(2)", &schema)];
let plan = final_aggregate_exec(
partial_aggregate_exec(
@@ -365,11 +374,7 @@ mod tests {
#[test]
fn aggregations_combined() -> Result<()> {
let schema = schema();
- let aggr_expr = vec![Arc::new(Count::new(
- lit(1i8),
- "COUNT(1)".to_string(),
- DataType::Int64,
- )) as _];
+ let aggr_expr = vec![count_expr(lit(1i8), "COUNT(1)", &schema)];
let plan = final_aggregate_exec(
partial_aggregate_exec(
diff --git
a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
index 1274fbe50a..f9d5a4c186 100644
--- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
+++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
@@ -517,10 +517,10 @@ mod tests {
let single_agg = AggregateExec::try_new(
AggregateMode::Single,
build_group_by(&schema.clone(), vec!["a".to_string()]),
- vec![agg.count_expr()], /* aggr_expr */
- vec![None], /* filter_expr */
- source, /* input */
- schema.clone(), /* input_schema */
+ vec![agg.count_expr(&schema)], /* aggr_expr */
+ vec![None], /* filter_expr */
+ source, /* input */
+ schema.clone(), /* input_schema */
)?;
let limit_exec = LocalLimitExec::new(
Arc::new(single_agg),
@@ -554,10 +554,10 @@ mod tests {
let single_agg = AggregateExec::try_new(
AggregateMode::Single,
build_group_by(&schema.clone(), vec!["a".to_string()]),
- vec![agg.count_expr()], /* aggr_expr */
- vec![filter_expr], /* filter_expr */
- source, /* input */
- schema.clone(), /* input_schema */
+ vec![agg.count_expr(&schema)], /* aggr_expr */
+ vec![filter_expr], /* filter_expr */
+ source, /* input */
+ schema.clone(), /* input_schema */
)?;
let limit_exec = LocalLimitExec::new(
Arc::new(single_agg),
diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs
b/datafusion/core/src/physical_optimizer/test_utils.rs
index 5895c39a5f..154e77cd23 100644
--- a/datafusion/core/src/physical_optimizer/test_utils.rs
+++ b/datafusion/core/src/physical_optimizer/test_utils.rs
@@ -43,7 +43,8 @@ use arrow_schema::{Schema, SchemaRef, SortOptions};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::JoinType;
use datafusion_execution::object_store::ObjectStoreUrl;
-use datafusion_expr::{AggregateFunction, WindowFrame,
WindowFunctionDefinition};
+use datafusion_expr::{WindowFrame, WindowFunctionDefinition};
+use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use datafusion_physical_plan::displayable;
@@ -240,7 +241,7 @@ pub fn bounded_window_exec(
Arc::new(
crate::physical_plan::windows::BoundedWindowAggExec::try_new(
vec![create_window_expr(
-
&WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
+ &WindowFunctionDefinition::AggregateUDF(count_udaf()),
"count".to_owned(),
&[col(col_name, &schema).unwrap()],
&[],
diff --git a/datafusion/core/src/physical_planner.rs
b/datafusion/core/src/physical_planner.rs
index 79033643cf..4f91875950 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -2181,7 +2181,6 @@ impl DefaultPhysicalPlanner {
expr: &[Expr],
) -> Result<Arc<dyn ExecutionPlan>> {
let input_schema = input.as_ref().schema();
-
let physical_exprs = expr
.iter()
.map(|e| {
diff --git
a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
index 8c9cffcf08..068383b200 100644
--- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
@@ -35,6 +35,7 @@ use datafusion::scalar::ScalarValue;
use datafusion_common::cast::as_primitive_array;
use datafusion_common::{internal_err, not_impl_err};
use datafusion_expr::expr::{BinaryExpr, Cast};
+use datafusion_functions_aggregate::expr_fn::count;
use datafusion_physical_expr::EquivalenceProperties;
use async_trait::async_trait;
diff --git a/datafusion/core/tests/dataframe/mod.rs
b/datafusion/core/tests/dataframe/mod.rs
index befd98d043..fa364c5f2a 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -31,6 +31,7 @@ use arrow::{
};
use arrow_array::Float32Array;
use arrow_schema::ArrowError;
+use datafusion_functions_aggregate::count::count_udaf;
use object_store::local::LocalFileSystem;
use std::fs;
use std::sync::Arc;
@@ -51,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
- array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max,
out_ref_col,
- placeholder, scalar_subquery, when, wildcard, AggregateFunction, Expr,
ExprSchemable,
- WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
+ array_agg, avg, cast, col, exists, expr, in_subquery, lit, max,
out_ref_col,
+ placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable,
WindowFrame,
+ WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
-use datafusion_functions_aggregate::expr_fn::sum;
+use datafusion_functions_aggregate::expr_fn::{count, sum};
#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
@@ -178,7 +179,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
.table("t1")
.await?
.select(vec![Expr::WindowFunction(expr::WindowFunction::new(
-
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
+ WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index b85f6376c3..4358691ee5 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -38,6 +38,7 @@ use datafusion_expr::{
AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
+use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
@@ -165,7 +166,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
// )
(
// Window function
-
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
+ WindowFunctionDefinition::AggregateUDF(count_udaf()),
// its name
"COUNT",
// window function argument
@@ -350,7 +351,7 @@ fn get_random_function(
window_fn_map.insert(
"count",
(
-
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
+ WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![arg.clone()],
),
);
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 98ab8ec251..57f5414c13 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -1861,6 +1861,7 @@ fn write_name<W: Write>(w: &mut W, e: &Expr) ->
Result<()> {
null_treatment,
}) => {
write_function_name(w, &fun.to_string(), false, args)?;
+
if let Some(nt) = null_treatment {
w.write_str(" ")?;
write!(w, "{}", nt)?;
@@ -1885,7 +1886,6 @@ fn write_name<W: Write>(w: &mut W, e: &Expr) ->
Result<()> {
null_treatment,
}) => {
write_function_name(w, func_def.name(), *distinct, args)?;
-
if let Some(fe) = filter {
write!(w, " FILTER (WHERE {fe})")?;
};
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 4203120508..1fafc63e96 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -193,6 +193,7 @@ pub fn avg(expr: Expr) -> Expr {
}
/// Create an expression to represent the count() aggregate function
+// TODO: Remove this and use `expr_fn::count` instead
pub fn count(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Count,
@@ -250,6 +251,7 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
}
/// Create an expression to represent the count(distinct) aggregate function
+// TODO: Remove this and use `expr_fn::count_distinct` instead
pub fn count_distinct(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Count,
diff --git a/datafusion/functions-aggregate/src/count.rs
b/datafusion/functions-aggregate/src/count.rs
new file mode 100644
index 0000000000..cfd5661953
--- /dev/null
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -0,0 +1,562 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use ahash::RandomState;
+use std::collections::HashSet;
+use std::ops::BitAnd;
+use std::{fmt::Debug, sync::Arc};
+
+use arrow::{
+ array::{ArrayRef, AsArray},
+ datatypes::{
+ DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Field,
+ Float16Type, Float32Type, Float64Type, Int16Type, Int32Type,
Int64Type, Int8Type,
+ Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
+ Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
+ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
+ UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+ },
+};
+
+use arrow::{
+ array::{Array, BooleanArray, Int64Array, PrimitiveArray},
+ buffer::BooleanBuffer,
+};
+use datafusion_common::{
+ downcast_value, internal_err, DataFusionError, Result, ScalarValue,
+};
+use datafusion_expr::function::StateFieldsArgs;
+use datafusion_expr::{
+ function::AccumulatorArgs, utils::format_state_name, Accumulator,
AggregateUDFImpl,
+ EmitTo, GroupsAccumulator, Signature, Volatility,
+};
+use datafusion_expr::{Expr, ReversedUDAF};
+use
datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
+use datafusion_physical_expr_common::{
+ aggregate::count_distinct::{
+ BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
+ PrimitiveDistinctCountAccumulator,
+ },
+ binary_map::OutputType,
+};
+
+make_udaf_expr_and_func!(
+ Count,
+ count,
+ expr,
+ "Count the number of non-null values in the column",
+ count_udaf
+);
+
+/// Builds an `Expr::AggregateFunction` that applies the `COUNT` UDAF
+/// (`count_udaf()`) to the single argument `expr`.
+///
+/// The third argument passed to `new_udf` is `true`; presumably this is the
+/// `distinct` flag, making the result the expression form of
+/// `COUNT(DISTINCT expr)` -- confirm against `AggregateFunction::new_udf`.
+pub fn count_distinct(expr: Expr) -> datafusion_expr::Expr {
+ datafusion_expr::Expr::AggregateFunction(
+ datafusion_expr::expr::AggregateFunction::new_udf(
+ count_udaf(),
+ vec![expr],
+ true,
+ None,
+ None,
+ None,
+ ),
+ )
+}
+
+/// `COUNT` aggregate function, exposed through its [`AggregateUDFImpl`]
+/// implementation.
+pub struct Count {
+ /// Signature handed out by `AggregateUDFImpl::signature`.
+ signature: Signature,
+ /// Alternative names handed out by `AggregateUDFImpl::aliases`.
+ aliases: Vec<String>,
+}
+
+impl Debug for Count {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ f.debug_struct("Count")
+ .field("name", &self.name())
+ .field("signature", &self.signature)
+ .finish()
+ }
+}
+
+impl Default for Count {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl Count {
+ pub fn new() -> Self {
+ Self {
+ aliases: vec!["count".to_string()],
+ signature: Signature::variadic_any(Volatility::Immutable),
+ }
+ }
+}
+
+impl AggregateUDFImpl for Count {
+ fn as_any(&self) -> &dyn std::any::Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "COUNT"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ Ok(DataType::Int64)
+ }
+
+ fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+ if args.is_distinct {
+ Ok(vec![Field::new_list(
+ format_state_name(args.name, "count distinct"),
+ Field::new("item", args.input_type.clone(), true),
+ false,
+ )])
+ } else {
+ Ok(vec![Field::new(
+ format_state_name(args.name, "count"),
+ DataType::Int64,
+ true,
+ )])
+ }
+ }
+
+ fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
+ if !acc_args.is_distinct {
+ return Ok(Box::new(CountAccumulator::new()));
+ }
+
+ let data_type = acc_args.input_type;
+ Ok(match data_type {
+ // try and use a specialized accumulator if possible, otherwise
fall back to generic accumulator
+ DataType::Int8 => Box::new(
+ PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type),
+ ),
+ DataType::Int16 => Box::new(
+ PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type),
+ ),
+ DataType::Int32 => Box::new(
+ PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type),
+ ),
+ DataType::Int64 => Box::new(
+ PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type),
+ ),
+ DataType::UInt8 => Box::new(
+ PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type),
+ ),
+ DataType::UInt16 => Box::new(
+
PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
+ ),
+ DataType::UInt32 => Box::new(
+
PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
+ ),
+ DataType::UInt64 => Box::new(
+
PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
+ ),
+ DataType::Decimal128(_, _) =>
Box::new(PrimitiveDistinctCountAccumulator::<
+ Decimal128Type,
+ >::new(data_type)),
+ DataType::Decimal256(_, _) =>
Box::new(PrimitiveDistinctCountAccumulator::<
+ Decimal256Type,
+ >::new(data_type)),
+
+ DataType::Date32 => Box::new(
+
PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
+ ),
+ DataType::Date64 => Box::new(
+
PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
+ ),
+ DataType::Time32(TimeUnit::Millisecond) => Box::new(
+
PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(
+ data_type,
+ ),
+ ),
+ DataType::Time32(TimeUnit::Second) => Box::new(
+
PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
+ ),
+ DataType::Time64(TimeUnit::Microsecond) => Box::new(
+
PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(
+ data_type,
+ ),
+ ),
+ DataType::Time64(TimeUnit::Nanosecond) => Box::new(
+
PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
+ ),
+ DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
+
PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(
+ data_type,
+ ),
+ ),
+ DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
+
PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(
+ data_type,
+ ),
+ ),
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
+
PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(
+ data_type,
+ ),
+ ),
+ DataType::Timestamp(TimeUnit::Second, _) => Box::new(
+
PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
+ ),
+
+ DataType::Float16 => {
+ Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
+ }
+ DataType::Float32 => {
+ Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
+ }
+ DataType::Float64 => {
+ Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
+ }
+
+ DataType::Utf8 => {
+
Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
+ }
+ DataType::LargeUtf8 => {
+
Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
+ }
+ DataType::Binary =>
Box::new(BytesDistinctCountAccumulator::<i32>::new(
+ OutputType::Binary,
+ )),
+ DataType::LargeBinary =>
Box::new(BytesDistinctCountAccumulator::<i64>::new(
+ OutputType::Binary,
+ )),
+
+ // Use the generic accumulator based on `ScalarValue` for all
other types
+ _ => Box::new(DistinctCountAccumulator {
+ values: HashSet::default(),
+ state_data_type: data_type.clone(),
+ }),
+ })
+ }
+
+ fn aliases(&self) -> &[String] {
+ &self.aliases
+ }
+
+ fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+ // groups accumulator only supports `COUNT(c1)`, not
+ // `COUNT(c1, c2)`, etc
+ if args.is_distinct {
+ return false;
+ }
+ args.args_num == 1
+ }
+
+ fn create_groups_accumulator(
+ &self,
+ _args: AccumulatorArgs,
+ ) -> Result<Box<dyn GroupsAccumulator>> {
+ // instantiate specialized accumulator
+ Ok(Box::new(CountGroupsAccumulator::new()))
+ }
+
+ fn reverse_expr(&self) -> ReversedUDAF {
+ ReversedUDAF::Identical
+ }
+}
+
+/// Accumulator used by `Count` for the non-distinct case: it keeps a single
+/// running total.
+#[derive(Debug)]
+struct CountAccumulator {
+ /// Running count; adjusted by `update_batch`, `retract_batch` and
+ /// `merge_batch`.
+ count: i64,
+}
+
+impl CountAccumulator {
+ /// new count accumulator
+ pub fn new() -> Self {
+ Self { count: 0 }
+ }
+}
+
+impl Accumulator for CountAccumulator {
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
+ Ok(vec![ScalarValue::Int64(Some(self.count))])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let array = &values[0];
+ self.count += (array.len() - null_count_for_multiple_cols(values)) as
i64;
+ Ok(())
+ }
+
+ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let array = &values[0];
+ self.count -= (array.len() - null_count_for_multiple_cols(values)) as
i64;
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ let counts = downcast_value!(states[0], Int64Array);
+ let delta = &arrow::compute::sum(counts);
+ if let Some(d) = delta {
+ self.count += *d;
+ }
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> Result<ScalarValue> {
+ Ok(ScalarValue::Int64(Some(self.count)))
+ }
+
+ fn supports_retract_batch(&self) -> bool {
+ true
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+}
+
+/// A [`GroupsAccumulator`] that computes `COUNT` for each group.
+/// Keeps one plain `i64` counter per group, updated with `+=`
+/// (no explicit overflow checking is performed).
+///
+/// Unlike most other accumulators, COUNT never produces NULLs. If no
+/// non-null values are seen in a group the output for it is 0. Thus, this
+/// accumulator has no additional null or seen filter tracking.
+#[derive(Debug)]
+struct CountGroupsAccumulator {
+ /// Count per group.
+ ///
+ /// Note this is an i64 and not a u64 (or usize) because the
+ /// output type of count is `DataType::Int64`. Thus by using `i64`
+ /// for the counts, the output [`Int64Array`] can be created
+ /// without copy.
+ counts: Vec<i64>,
+}
+
+impl CountGroupsAccumulator {
+ pub fn new() -> Self {
+ Self { counts: vec![] }
+ }
+}
+
+impl GroupsAccumulator for CountGroupsAccumulator {
+ fn update_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ assert_eq!(values.len(), 1, "single argument to update_batch");
+ let values = &values[0];
+
+ // Add one to each group's counter for each non null, non
+ // filtered value
+ self.counts.resize(total_num_groups, 0);
+ accumulate_indices(
+ group_indices,
+ values.logical_nulls().as_ref(),
+ opt_filter,
+ |group_index| {
+ self.counts[group_index] += 1;
+ },
+ );
+
+ Ok(())
+ }
+
+ /// Adds partial counts (the state produced by `state`) to the per-group
+ /// counters.
+ ///
+ /// `values` must hold exactly one `Int64` column without nulls; row `i`
+ /// is added to group `group_indices[i]`. When `opt_filter` is given, only
+ /// rows whose filter value is `Some(true)` are merged.
+ fn merge_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ assert_eq!(values.len(), 1, "one argument to merge_batch");
+ // the single state column holds the partial counts
+ let partial_counts = values[0].as_primitive::<Int64Type>();
+
+ // intermediate counts are always created as non null
+ assert_eq!(partial_counts.null_count(), 0);
+ let partial_counts = partial_counts.values();
+
+ // Adds the partial counts to the current counts, growing `counts` with
+ // zeros first so every group index seen so far is addressable
+ self.counts.resize(total_num_groups, 0);
+ match opt_filter {
+ Some(filter) => filter
+ .iter()
+ .zip(group_indices.iter())
+ .zip(partial_counts.iter())
+ .for_each(|((filter_value, &group_index), partial_count)| {
+ // null and `false` filter values both skip the row
+ if let Some(true) = filter_value {
+ self.counts[group_index] += partial_count;
+ }
+ }),
+ None => group_indices.iter().zip(partial_counts.iter()).for_each(
+ |(&group_index, partial_count)| {
+ self.counts[group_index] += partial_count;
+ },
+ ),
+ }
+
+ Ok(())
+ }
+
+ fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+ let counts = emit_to.take_needed(&mut self.counts);
+
+ // Count is always non null (null inputs just don't contribute to the
overall values)
+ let nulls = None;
+ let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
+
+ Ok(Arc::new(array))
+ }
+
+ // return arrays for counts
+ fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+ let counts = emit_to.take_needed(&mut self.counts);
+ let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); //
zero copy, no nulls
+ Ok(vec![Arc::new(counts) as ArrayRef])
+ }
+
+ /// Estimated heap usage of this accumulator: the allocated capacity of
+ /// `counts` multiplied by the size of one element.
+ fn size(&self) -> usize {
+ // `counts` is a `Vec<i64>`, so measure `i64` elements; `usize` only has
+ // the same width on 64-bit targets.
+ self.counts.capacity() * std::mem::size_of::<i64>()
+ }
+}
+
+/// Counts the rows in which at least one of the `values` columns is null.
+///
+/// With a single column this is that column's logical null count (0 when it
+/// reports no null buffer). With several columns the null buffers are ANDed
+/// together and the rows whose combined bit is unset are counted, so a row
+/// counts once as soon as any column is null in it.
+///
+/// Indexes `values[0]`, so `values` must be non-empty. All columns are
+/// assumed to have the same length as `values[0]` -- TODO confirm with callers.
+fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
+ if values.len() > 1 {
+ // Fold the per-column null buffers into one; columns without a null
+ // buffer (`None`) leave the accumulated buffer unchanged.
+ let result_bool_buf: Option<BooleanBuffer> = values
+ .iter()
+ .map(|a| a.logical_nulls())
+ .fold(None, |acc, b| match (acc, b) {
+ (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
+ (Some(acc), None) => Some(acc),
+ (None, Some(b)) => Some(b.into_inner()),
+ _ => None,
+ });
+ // Set bits mark rows that are non-null in every column; the rest are
+ // the rows to report. No buffer at all means no nulls anywhere.
+ result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
+ } else {
+ values[0]
+ .logical_nulls()
+ .map_or(0, |nulls| nulls.null_count())
+ }
+}
+
+/// General purpose distinct accumulator that works for any DataType by using
+/// [`ScalarValue`].
+///
+/// It stores intermediate results as a `ListArray`
+///
+/// Note that many types have specialized accumulators that are (much)
+/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
+/// [`BytesDistinctCountAccumulator`]
+#[derive(Debug)]
+struct DistinctCountAccumulator {
+ values: HashSet<ScalarValue, RandomState>,
+ state_data_type: DataType,
+}
+
+impl DistinctCountAccumulator {
+ // Cheap size estimate intended for fixed length values: the struct itself,
+ // plus `size_of::<ScalarValue>()` times the set's capacity, plus the extra
+ // size (beyond the inline `ScalarValue`) of the first value the set yields,
+ // if any. Unlike `full_size()` it does not visit every value, so it is
+ // faster but not suitable for variable length values like strings or
+ // complex types
+ fn fixed_size(&self) -> usize {
+ std::mem::size_of_val(self)
+ + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
+ + self
+ .values
+ .iter()
+ .next()
+ .map(|vals| ScalarValue::size(vals) -
std::mem::size_of_val(vals))
+ .unwrap_or(0)
+ + std::mem::size_of::<DataType>()
+ }
+
+ // Calculates the size as accurately as possible: same terms as
+ // `fixed_size()`, but the extra size of every stored value is summed.
+ // Note that calling this method is expensive because it walks the whole set
+ fn full_size(&self) -> usize {
+ std::mem::size_of_val(self)
+ + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
+ + self
+ .values
+ .iter()
+ .map(|vals| ScalarValue::size(vals) -
std::mem::size_of_val(vals))
+ .sum::<usize>()
+ + std::mem::size_of::<DataType>()
+ }
+}
+
+impl Accumulator for DistinctCountAccumulator {
+ /// Returns the distinct values seen so far as (one element) ListArray.
+ fn state(&mut self) -> Result<Vec<ScalarValue>> {
+ let scalars = self.values.iter().cloned().collect::<Vec<_>>();
+ let arr = ScalarValue::new_list(scalars.as_slice(),
&self.state_data_type);
+ Ok(vec![ScalarValue::List(arr)])
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ if values.is_empty() {
+ return Ok(());
+ }
+
+ let arr = &values[0];
+ if arr.data_type() == &DataType::Null {
+ return Ok(());
+ }
+
+ (0..arr.len()).try_for_each(|index| {
+ if !arr.is_null(index) {
+ let scalar = ScalarValue::try_from_array(arr, index)?;
+ self.values.insert(scalar);
+ }
+ Ok(())
+ })
+ }
+
+ /// Merges multiple sets of distinct values into the current set.
+ ///
+ /// The input to this function is a `ListArray` with **multiple** rows,
+ /// where each row contains the values from a partial aggregate's phase
(e.g.
+ /// the result of calling `Self::state` on multiple accumulators).
+ ///
+ /// Returns an internal error if any row of the list array is null, and
+ /// panics if more than one state column is passed.
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ }
+ // `state` emits exactly one list column, so exactly one is expected back
+ assert_eq!(states.len(), 1, "count_distinct states must be singleton!");
+ let array = &states[0];
+ let list_array = array.as_list::<i32>();
+ for inner_array in list_array.iter() {
+ let Some(inner_array) = inner_array else {
+ return internal_err!(
+ "Intermediate results of COUNT DISTINCT should always be
non null"
+ );
+ };
+ // each row is itself an array of distinct values: fold it in like a batch
+ self.update_batch(&[inner_array])?;
+ }
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> Result<ScalarValue> {
+ Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+ }
+
+ fn size(&self) -> usize {
+ match &self.state_data_type {
+ DataType::Boolean | DataType::Null => self.fixed_size(),
+ d if d.is_primitive() => self.fixed_size(),
+ _ => self.full_size(),
+ }
+ }
+}
diff --git a/datafusion/functions-aggregate/src/lib.rs
b/datafusion/functions-aggregate/src/lib.rs
index 2d062cf2cb..56fc1305bb 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -56,6 +56,7 @@
pub mod macros;
pub mod approx_distinct;
+pub mod count;
pub mod covariance;
pub mod first_last;
pub mod hyperloglog;
@@ -77,6 +78,8 @@ use std::sync::Arc;
pub mod expr_fn {
pub use super::approx_distinct;
pub use super::approx_median::approx_median;
+ pub use super::count::count;
+ pub use super::count::count_distinct;
pub use super::covariance::covar_pop;
pub use super::covariance::covar_samp;
pub use super::first_last::first_value;
@@ -98,6 +101,7 @@ pub fn all_default_aggregate_functions() ->
Vec<Arc<AggregateUDF>> {
sum::sum_udaf(),
covariance::covar_pop_udaf(),
median::median_udaf(),
+ count::count_udaf(),
variance::var_samp_udaf(),
variance::var_pop_udaf(),
stddev::stddev_udaf(),
@@ -133,8 +137,8 @@ mod tests {
let mut names = HashSet::new();
for func in all_default_aggregate_functions() {
// TODO: remove this
- // sum is in intermidiate migration state, skip this
- if func.name().to_lowercase() == "sum" {
+ // These functions are in intermediate migration state, skip them
+ if func.name().to_lowercase() == "count" {
continue;
}
assert!(
diff --git a/datafusion/optimizer/src/decorrelate.rs
b/datafusion/optimizer/src/decorrelate.rs
index b55b1a7f8f..e14ee763a3 100644
--- a/datafusion/optimizer/src/decorrelate.rs
+++ b/datafusion/optimizer/src/decorrelate.rs
@@ -441,8 +441,14 @@ fn agg_exprs_evaluation_result_on_empty_batch(
Transformed::yes(Expr::Literal(ScalarValue::Null))
}
}
- AggregateFunctionDefinition::UDF { .. } => {
- Transformed::yes(Expr::Literal(ScalarValue::Null))
+ AggregateFunctionDefinition::UDF(fun) => {
+ if fun.name() == "COUNT" {
+
Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(
+ 0,
+ ))))
+ } else {
+
Transformed::yes(Expr::Literal(ScalarValue::Null))
+ }
}
},
_ => Transformed::no(expr),
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs
b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 32b6703bca..e738209eb4 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -361,8 +361,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
mod tests {
use super::*;
use crate::test::*;
- use datafusion_expr::expr;
- use datafusion_expr::expr::GroupingSet;
+ use datafusion_expr::expr::{self, GroupingSet};
use datafusion_expr::test::function_stub::{sum, sum_udaf};
use datafusion_expr::{
count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder,
max, min,
diff --git a/datafusion/physical-expr-common/Cargo.toml
b/datafusion/physical-expr-common/Cargo.toml
index 637b877511..3ef2d53455 100644
--- a/datafusion/physical-expr-common/Cargo.toml
+++ b/datafusion/physical-expr-common/Cargo.toml
@@ -36,7 +36,9 @@ name = "datafusion_physical_expr_common"
path = "src/lib.rs"
[dependencies]
+ahash = { workspace = true }
arrow = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
+hashbrown = { workspace = true }
rand = { workspace = true }
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs
b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs
similarity index 93%
rename from datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs
rename to datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs
index 2ed9b002c8..5c888ca66c 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs
+++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs
@@ -18,7 +18,7 @@
//! [`BytesDistinctCountAccumulator`] for Utf8/LargeUtf8/Binary/LargeBinary
values
use crate::binary_map::{ArrowBytesSet, OutputType};
-use arrow_array::{ArrayRef, OffsetSizeTrait};
+use arrow::array::{ArrayRef, OffsetSizeTrait};
use datafusion_common::cast::as_list_array;
use datafusion_common::utils::array_into_list_array;
use datafusion_common::ScalarValue;
@@ -35,10 +35,10 @@ use std::sync::Arc;
/// [`BinaryArray`]: arrow::array::BinaryArray
/// [`LargeBinaryArray`]: arrow::array::LargeBinaryArray
#[derive(Debug)]
-pub(super) struct BytesDistinctCountAccumulator<O:
OffsetSizeTrait>(ArrowBytesSet<O>);
+pub struct BytesDistinctCountAccumulator<O: OffsetSizeTrait>(ArrowBytesSet<O>);
impl<O: OffsetSizeTrait> BytesDistinctCountAccumulator<O> {
- pub(super) fn new(output_type: OutputType) -> Self {
+ pub fn new(output_type: OutputType) -> Self {
Self(ArrowBytesSet::new(output_type))
}
}
diff --git a/datafusion/physical-expr-common/src/lib.rs
b/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs
similarity index 82%
copy from datafusion/physical-expr-common/src/lib.rs
copy to datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs
index f335958698..f216406d0d 100644
--- a/datafusion/physical-expr-common/src/lib.rs
+++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs
@@ -15,9 +15,9 @@
// specific language governing permissions and limitations
// under the License.
-pub mod aggregate;
-pub mod expressions;
-pub mod physical_expr;
-pub mod sort_expr;
-pub mod tree_node;
-pub mod utils;
+mod bytes;
+mod native;
+
+pub use bytes::BytesDistinctCountAccumulator;
+pub use native::FloatDistinctCountAccumulator;
+pub use native::PrimitiveDistinctCountAccumulator;
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs
b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs
similarity index 93%
rename from datafusion/physical-expr/src/aggregate/count_distinct/native.rs
rename to datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs
index 0e7483d4a1..72b83676e8 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs
+++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs
@@ -26,10 +26,10 @@ use std::hash::Hash;
use std::sync::Arc;
use ahash::RandomState;
+use arrow::array::types::ArrowPrimitiveType;
use arrow::array::ArrayRef;
-use arrow_array::types::ArrowPrimitiveType;
-use arrow_array::PrimitiveArray;
-use arrow_schema::DataType;
+use arrow::array::PrimitiveArray;
+use arrow::datatypes::DataType;
use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::array_into_list_array;
@@ -40,7 +40,7 @@ use datafusion_expr::Accumulator;
use crate::aggregate::utils::Hashable;
#[derive(Debug)]
-pub(super) struct PrimitiveDistinctCountAccumulator<T>
+pub struct PrimitiveDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
T::Native: Eq + Hash,
@@ -54,7 +54,7 @@ where
T: ArrowPrimitiveType + Send,
T::Native: Eq + Hash,
{
- pub(super) fn new(data_type: &DataType) -> Self {
+ pub fn new(data_type: &DataType) -> Self {
Self {
values: HashSet::default(),
data_type: data_type.clone(),
@@ -125,7 +125,7 @@ where
}
#[derive(Debug)]
-pub(super) struct FloatDistinctCountAccumulator<T>
+pub struct FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
{
@@ -136,13 +136,22 @@ impl<T> FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
{
- pub(super) fn new() -> Self {
+ pub fn new() -> Self {
Self {
values: HashSet::default(),
}
}
}
+impl<T> Default for FloatDistinctCountAccumulator<T>
+where
+ T: ArrowPrimitiveType + Send,
+{
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
impl<T> Accumulator for FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send + Debug,
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs
b/datafusion/physical-expr-common/src/aggregate/mod.rs
index ec02df57b8..21884f840d 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+pub mod count_distinct;
pub mod groups_accumulator;
pub mod stats;
pub mod tdigest;
diff --git a/datafusion/physical-expr/src/binary_map.rs
b/datafusion/physical-expr-common/src/binary_map.rs
similarity index 98%
rename from datafusion/physical-expr/src/binary_map.rs
rename to datafusion/physical-expr-common/src/binary_map.rs
index 0923fcdaeb..6d5ba737a1 100644
--- a/datafusion/physical-expr/src/binary_map.rs
+++ b/datafusion/physical-expr-common/src/binary_map.rs
@@ -19,17 +19,16 @@
//! StringArray / LargeStringArray / BinaryArray / LargeBinaryArray.
use ahash::RandomState;
-use arrow_array::cast::AsArray;
-use arrow_array::types::{ByteArrayType, GenericBinaryType, GenericStringType};
-use arrow_array::{
- Array, ArrayRef, GenericBinaryArray, GenericStringArray, OffsetSizeTrait,
+use arrow::array::cast::AsArray;
+use arrow::array::types::{ByteArrayType, GenericBinaryType, GenericStringType};
+use arrow::array::{
+ Array, ArrayRef, BooleanBufferBuilder, BufferBuilder, GenericBinaryArray,
+ GenericStringArray, OffsetSizeTrait,
};
-use arrow_buffer::{
- BooleanBufferBuilder, BufferBuilder, NullBuffer, OffsetBuffer,
ScalarBuffer,
-};
-use arrow_schema::DataType;
+use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
+use arrow::datatypes::DataType;
use datafusion_common::hash_utils::create_hashes;
-use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
+use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
use std::any::type_name;
use std::fmt::Debug;
use std::mem;
@@ -605,8 +604,8 @@ where
#[cfg(test)]
mod tests {
use super::*;
- use arrow_array::{BinaryArray, LargeBinaryArray, StringArray};
- use hashbrown::HashMap;
+ use arrow::array::{BinaryArray, LargeBinaryArray, StringArray};
+ use std::collections::HashMap;
#[test]
fn string_set_empty() {
diff --git a/datafusion/physical-expr-common/src/lib.rs
b/datafusion/physical-expr-common/src/lib.rs
index f335958698..0ddb84141a 100644
--- a/datafusion/physical-expr-common/src/lib.rs
+++ b/datafusion/physical-expr-common/src/lib.rs
@@ -16,6 +16,7 @@
// under the License.
pub mod aggregate;
+pub mod binary_map;
pub mod expressions;
pub mod physical_expr;
pub mod sort_expr;
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs
b/datafusion/physical-expr/src/aggregate/build_in.rs
index ac24dd2e76..aee7bca3b8 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -30,12 +30,13 @@ use std::sync::Arc;
use arrow::datatypes::Schema;
+use datafusion_common::{exec_err, internal_err, not_impl_err, Result};
+use datafusion_expr::AggregateFunction;
+
use crate::aggregate::average::Avg;
use crate::aggregate::regr::RegrType;
use crate::expressions::{self, Literal};
use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
-use datafusion_common::{exec_err, not_impl_err, Result};
-use datafusion_expr::AggregateFunction;
/// Create a physical aggregation expression.
/// This function errors when `input_phy_exprs`' can't be coerced to a valid
argument type of the aggregation function.
pub fn create_aggregate_expr(
@@ -60,14 +61,9 @@ pub fn create_aggregate_expr(
.collect::<Result<Vec<_>>>()?;
let input_phy_exprs = input_phy_exprs.to_vec();
Ok(match (fun, distinct) {
- (AggregateFunction::Count, false) => Arc::new(
- expressions::Count::new_with_multiple_exprs(input_phy_exprs, name,
data_type),
- ),
- (AggregateFunction::Count, true) =>
Arc::new(expressions::DistinctCount::new(
- data_type,
- input_phy_exprs[0].clone(),
- name,
- )),
+ (AggregateFunction::Count, _) => {
+ return internal_err!("Builtin Count will be removed");
+ }
(AggregateFunction::Grouping, _) =>
Arc::new(expressions::Grouping::new(
input_phy_exprs[0].clone(),
name,
@@ -320,7 +316,7 @@ mod tests {
use super::*;
use crate::expressions::{
try_cast, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor,
BoolAnd,
- BoolOr, Count, DistinctArrayAgg, DistinctCount, Max, Min,
+ BoolOr, DistinctArrayAgg, Max, Min,
};
use datafusion_common::{plan_err, DataFusionError, ScalarValue};
@@ -328,8 +324,8 @@ mod tests {
use datafusion_expr::{type_coercion, Signature};
#[test]
- fn test_count_arragg_approx_expr() -> Result<()> {
- let funcs = vec![AggregateFunction::Count,
AggregateFunction::ArrayAgg];
+ fn test_approx_expr() -> Result<()> {
+ let funcs = vec![AggregateFunction::ArrayAgg];
let data_types = vec![
DataType::UInt32,
DataType::Int32,
@@ -352,29 +348,18 @@ mod tests {
&input_schema,
"c1",
)?;
- match fun {
- AggregateFunction::Count => {
- assert!(result_agg_phy_exprs.as_any().is::<Count>());
- assert_eq!("c1", result_agg_phy_exprs.name());
- assert_eq!(
- Field::new("c1", DataType::Int64, true),
- result_agg_phy_exprs.field().unwrap()
- );
- }
- AggregateFunction::ArrayAgg => {
-
assert!(result_agg_phy_exprs.as_any().is::<ArrayAgg>());
- assert_eq!("c1", result_agg_phy_exprs.name());
- assert_eq!(
- Field::new_list(
- "c1",
- Field::new("item", data_type.clone(), true),
- true,
- ),
- result_agg_phy_exprs.field().unwrap()
- );
- }
- _ => {}
- };
+ if fun == AggregateFunction::ArrayAgg {
+ assert!(result_agg_phy_exprs.as_any().is::<ArrayAgg>());
+ assert_eq!("c1", result_agg_phy_exprs.name());
+ assert_eq!(
+ Field::new_list(
+ "c1",
+ Field::new("item", data_type.clone(), true),
+ true,
+ ),
+ result_agg_phy_exprs.field().unwrap()
+ );
+ }
let result_distinct = create_physical_agg_expr_for_test(
&fun,
@@ -383,29 +368,18 @@ mod tests {
&input_schema,
"c1",
)?;
- match fun {
- AggregateFunction::Count => {
-
assert!(result_distinct.as_any().is::<DistinctCount>());
- assert_eq!("c1", result_distinct.name());
- assert_eq!(
- Field::new("c1", DataType::Int64, true),
- result_distinct.field().unwrap()
- );
- }
- AggregateFunction::ArrayAgg => {
-
assert!(result_distinct.as_any().is::<DistinctArrayAgg>());
- assert_eq!("c1", result_distinct.name());
- assert_eq!(
- Field::new_list(
- "c1",
- Field::new("item", data_type.clone(), true),
- true,
- ),
- result_agg_phy_exprs.field().unwrap()
- );
- }
- _ => {}
- };
+ if fun == AggregateFunction::ArrayAgg {
+ assert!(result_distinct.as_any().is::<DistinctArrayAgg>());
+ assert_eq!("c1", result_distinct.name());
+ assert_eq!(
+ Field::new_list(
+ "c1",
+ Field::new("item", data_type.clone(), true),
+ true,
+ ),
+ result_agg_phy_exprs.field().unwrap()
+ );
+ }
}
}
Ok(())
diff --git a/datafusion/physical-expr/src/aggregate/count.rs
b/datafusion/physical-expr/src/aggregate/count.rs
deleted file mode 100644
index aad18a82ab..0000000000
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ /dev/null
@@ -1,348 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Defines physical expressions that can evaluated at runtime during query
execution
-
-use std::any::Any;
-use std::fmt::Debug;
-use std::ops::BitAnd;
-use std::sync::Arc;
-
-use crate::aggregate::utils::down_cast_any_ref;
-use crate::{AggregateExpr, PhysicalExpr};
-use arrow::array::{Array, Int64Array};
-use arrow::compute;
-use arrow::datatypes::DataType;
-use arrow::{array::ArrayRef, datatypes::Field};
-use arrow_array::cast::AsArray;
-use arrow_array::types::Int64Type;
-use arrow_array::PrimitiveArray;
-use arrow_buffer::BooleanBuffer;
-use datafusion_common::{downcast_value, ScalarValue};
-use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator};
-
-use crate::expressions::format_state_name;
-
-use super::groups_accumulator::accumulate::accumulate_indices;
-
-/// COUNT aggregate expression
-/// Returns the amount of non-null values of the given expression.
-#[derive(Debug, Clone)]
-pub struct Count {
- name: String,
- data_type: DataType,
- nullable: bool,
- /// Input exprs
- ///
- /// For `COUNT(c1)` this is `[c1]`
- /// For `COUNT(c1, c2)` this is `[c1, c2]`
- exprs: Vec<Arc<dyn PhysicalExpr>>,
-}
-
-impl Count {
- /// Create a new COUNT aggregate function.
- pub fn new(
- expr: Arc<dyn PhysicalExpr>,
- name: impl Into<String>,
- data_type: DataType,
- ) -> Self {
- Self {
- name: name.into(),
- exprs: vec![expr],
- data_type,
- nullable: true,
- }
- }
-
- pub fn new_with_multiple_exprs(
- exprs: Vec<Arc<dyn PhysicalExpr>>,
- name: impl Into<String>,
- data_type: DataType,
- ) -> Self {
- Self {
- name: name.into(),
- exprs,
- data_type,
- nullable: true,
- }
- }
-}
-
-/// An accumulator to compute the counts of [`PrimitiveArray<T>`].
-/// Stores values as native types, and does overflow checking
-///
-/// Unlike most other accumulators, COUNT never produces NULLs. If no
-/// non-null values are seen in any group the output is 0. Thus, this
-/// accumulator has no additional null or seen filter tracking.
-#[derive(Debug)]
-struct CountGroupsAccumulator {
- /// Count per group.
- ///
- /// Note this is an i64 and not a u64 (or usize) because the
- /// output type of count is `DataType::Int64`. Thus by using `i64`
- /// for the counts, the output [`Int64Array`] can be created
- /// without copy.
- counts: Vec<i64>,
-}
-
-impl CountGroupsAccumulator {
- pub fn new() -> Self {
- Self { counts: vec![] }
- }
-}
-
-impl GroupsAccumulator for CountGroupsAccumulator {
- fn update_batch(
- &mut self,
- values: &[ArrayRef],
- group_indices: &[usize],
- opt_filter: Option<&arrow_array::BooleanArray>,
- total_num_groups: usize,
- ) -> Result<()> {
- assert_eq!(values.len(), 1, "single argument to update_batch");
- let values = &values[0];
-
- // Add one to each group's counter for each non null, non
- // filtered value
- self.counts.resize(total_num_groups, 0);
- accumulate_indices(
- group_indices,
- values.logical_nulls().as_ref(),
- opt_filter,
- |group_index| {
- self.counts[group_index] += 1;
- },
- );
-
- Ok(())
- }
-
- fn merge_batch(
- &mut self,
- values: &[ArrayRef],
- group_indices: &[usize],
- opt_filter: Option<&arrow_array::BooleanArray>,
- total_num_groups: usize,
- ) -> Result<()> {
- assert_eq!(values.len(), 1, "one argument to merge_batch");
- // first batch is counts, second is partial sums
- let partial_counts = values[0].as_primitive::<Int64Type>();
-
- // intermediate counts are always created as non null
- assert_eq!(partial_counts.null_count(), 0);
- let partial_counts = partial_counts.values();
-
- // Adds the counts with the partial counts
- self.counts.resize(total_num_groups, 0);
- match opt_filter {
- Some(filter) => filter
- .iter()
- .zip(group_indices.iter())
- .zip(partial_counts.iter())
- .for_each(|((filter_value, &group_index), partial_count)| {
- if let Some(true) = filter_value {
- self.counts[group_index] += partial_count;
- }
- }),
- None => group_indices.iter().zip(partial_counts.iter()).for_each(
- |(&group_index, partial_count)| {
- self.counts[group_index] += partial_count;
- },
- ),
- }
-
- Ok(())
- }
-
- fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
- let counts = emit_to.take_needed(&mut self.counts);
-
- // Count is always non null (null inputs just don't contribute to the
overall values)
- let nulls = None;
- let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
-
- Ok(Arc::new(array))
- }
-
- // return arrays for counts
- fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
- let counts = emit_to.take_needed(&mut self.counts);
- let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); //
zero copy, no nulls
- Ok(vec![Arc::new(counts) as ArrayRef])
- }
-
- fn size(&self) -> usize {
- self.counts.capacity() * std::mem::size_of::<usize>()
- }
-}
-
-/// count null values for multiple columns
-/// for each row if one column value is null, then null_count + 1
-fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
- if values.len() > 1 {
- let result_bool_buf: Option<BooleanBuffer> = values
- .iter()
- .map(|a| a.logical_nulls())
- .fold(None, |acc, b| match (acc, b) {
- (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
- (Some(acc), None) => Some(acc),
- (None, Some(b)) => Some(b.into_inner()),
- _ => None,
- });
- result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
- } else {
- values[0]
- .logical_nulls()
- .map_or(0, |nulls| nulls.null_count())
- }
-}
-
-impl AggregateExpr for Count {
- /// Return a reference to Any that can be used for downcasting
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn field(&self) -> Result<Field> {
- Ok(Field::new(&self.name, DataType::Int64, self.nullable))
- }
-
- fn state_fields(&self) -> Result<Vec<Field>> {
- Ok(vec![Field::new(
- format_state_name(&self.name, "count"),
- DataType::Int64,
- true,
- )])
- }
-
- fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- self.exprs.clone()
- }
-
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- Ok(Box::new(CountAccumulator::new()))
- }
-
- fn name(&self) -> &str {
- &self.name
- }
-
- fn groups_accumulator_supported(&self) -> bool {
- // groups accumulator only supports `COUNT(c1)`, not
- // `COUNT(c1, c2)`, etc
- self.exprs.len() == 1
- }
-
- fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
- Some(Arc::new(self.clone()))
- }
-
- fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- Ok(Box::new(CountAccumulator::new()))
- }
-
- fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
- // instantiate specialized accumulator
- Ok(Box::new(CountGroupsAccumulator::new()))
- }
-
- fn with_new_expressions(
- &self,
- args: Vec<Arc<dyn PhysicalExpr>>,
- order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
- ) -> Option<Arc<dyn AggregateExpr>> {
- debug_assert_eq!(self.exprs.len(), args.len());
- debug_assert!(order_by_exprs.is_empty());
- Some(Arc::new(Count {
- name: self.name.clone(),
- data_type: self.data_type.clone(),
- nullable: self.nullable,
- exprs: args,
- }))
- }
-}
-
-impl PartialEq<dyn Any> for Count {
- fn eq(&self, other: &dyn Any) -> bool {
- down_cast_any_ref(other)
- .downcast_ref::<Self>()
- .map(|x| {
- self.name == x.name
- && self.data_type == x.data_type
- && self.nullable == x.nullable
- && self.exprs.len() == x.exprs.len()
- && self
- .exprs
- .iter()
- .zip(x.exprs.iter())
- .all(|(expr1, expr2)| expr1.eq(expr2))
- })
- .unwrap_or(false)
- }
-}
-
-#[derive(Debug)]
-struct CountAccumulator {
- count: i64,
-}
-
-impl CountAccumulator {
- /// new count accumulator
- pub fn new() -> Self {
- Self { count: 0 }
- }
-}
-
-impl Accumulator for CountAccumulator {
- fn state(&mut self) -> Result<Vec<ScalarValue>> {
- Ok(vec![ScalarValue::Int64(Some(self.count))])
- }
-
- fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let array = &values[0];
- self.count += (array.len() - null_count_for_multiple_cols(values)) as
i64;
- Ok(())
- }
-
- fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let array = &values[0];
- self.count -= (array.len() - null_count_for_multiple_cols(values)) as
i64;
- Ok(())
- }
-
- fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
- let counts = downcast_value!(states[0], Int64Array);
- let delta = &compute::sum(counts);
- if let Some(d) = delta {
- self.count += *d;
- }
- Ok(())
- }
-
- fn evaluate(&mut self) -> Result<ScalarValue> {
- Ok(ScalarValue::Int64(Some(self.count)))
- }
-
- fn supports_retract_batch(&self) -> bool {
- true
- }
-
- fn size(&self) -> usize {
- std::mem::size_of_val(self)
- }
-}
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
deleted file mode 100644
index 52f1c5c0f9..0000000000
--- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
+++ /dev/null
@@ -1,718 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-mod bytes;
-mod native;
-
-use std::any::Any;
-use std::collections::HashSet;
-use std::fmt::Debug;
-use std::sync::Arc;
-
-use ahash::RandomState;
-use arrow::array::{Array, ArrayRef};
-use arrow::datatypes::{DataType, Field, TimeUnit};
-use arrow_array::cast::AsArray;
-use arrow_array::types::{
- Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type,
Float32Type,
- Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
Time32MillisecondType,
- Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
- TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType,
- TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
-};
-
-use datafusion_common::{internal_err, Result, ScalarValue};
-use datafusion_expr::Accumulator;
-
-use crate::aggregate::count_distinct::bytes::BytesDistinctCountAccumulator;
-use crate::aggregate::count_distinct::native::{
- FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator,
-};
-use crate::aggregate::utils::down_cast_any_ref;
-use crate::binary_map::OutputType;
-use crate::expressions::format_state_name;
-use crate::{AggregateExpr, PhysicalExpr};
-
-/// Expression for a `COUNT(DISTINCT)` aggregation.
-#[derive(Debug)]
-pub struct DistinctCount {
- /// Column name
- name: String,
- /// The DataType used to hold the state for each input
- state_data_type: DataType,
- /// The input arguments
- expr: Arc<dyn PhysicalExpr>,
-}
-
-impl DistinctCount {
- /// Create a new COUNT(DISTINCT) aggregate function.
- pub fn new(
- input_data_type: DataType,
- expr: Arc<dyn PhysicalExpr>,
- name: impl Into<String>,
- ) -> Self {
- Self {
- name: name.into(),
- state_data_type: input_data_type,
- expr,
- }
- }
-}
-
-impl AggregateExpr for DistinctCount {
- /// Return a reference to Any that can be used for downcasting
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn field(&self) -> Result<Field> {
- Ok(Field::new(&self.name, DataType::Int64, true))
- }
-
- fn state_fields(&self) -> Result<Vec<Field>> {
- Ok(vec![Field::new_list(
- format_state_name(&self.name, "count distinct"),
- Field::new("item", self.state_data_type.clone(), true),
- false,
- )])
- }
-
- fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- vec![self.expr.clone()]
- }
-
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- use DataType::*;
- use TimeUnit::*;
-
- let data_type = &self.state_data_type;
- Ok(match data_type {
- // try and use a specialized accumulator if possible, otherwise
fall back to generic accumulator
- Int8 =>
Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new(
- data_type,
- )),
- Int16 =>
Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new(
- data_type,
- )),
- Int32 =>
Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new(
- data_type,
- )),
- Int64 =>
Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new(
- data_type,
- )),
- UInt8 =>
Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new(
- data_type,
- )),
- UInt16 =>
Box::new(PrimitiveDistinctCountAccumulator::<UInt16Type>::new(
- data_type,
- )),
- UInt32 =>
Box::new(PrimitiveDistinctCountAccumulator::<UInt32Type>::new(
- data_type,
- )),
- UInt64 =>
Box::new(PrimitiveDistinctCountAccumulator::<UInt64Type>::new(
- data_type,
- )),
- Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
- Decimal128Type,
- >::new(data_type)),
- Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
- Decimal256Type,
- >::new(data_type)),
-
- Date32 =>
Box::new(PrimitiveDistinctCountAccumulator::<Date32Type>::new(
- data_type,
- )),
- Date64 =>
Box::new(PrimitiveDistinctCountAccumulator::<Date64Type>::new(
- data_type,
- )),
- Time32(Millisecond) =>
Box::new(PrimitiveDistinctCountAccumulator::<
- Time32MillisecondType,
- >::new(data_type)),
- Time32(Second) => Box::new(PrimitiveDistinctCountAccumulator::<
- Time32SecondType,
- >::new(data_type)),
- Time64(Microsecond) =>
Box::new(PrimitiveDistinctCountAccumulator::<
- Time64MicrosecondType,
- >::new(data_type)),
- Time64(Nanosecond) => Box::new(PrimitiveDistinctCountAccumulator::<
- Time64NanosecondType,
- >::new(data_type)),
- Timestamp(Microsecond, _) =>
Box::new(PrimitiveDistinctCountAccumulator::<
- TimestampMicrosecondType,
- >::new(data_type)),
- Timestamp(Millisecond, _) =>
Box::new(PrimitiveDistinctCountAccumulator::<
- TimestampMillisecondType,
- >::new(data_type)),
- Timestamp(Nanosecond, _) =>
Box::new(PrimitiveDistinctCountAccumulator::<
- TimestampNanosecondType,
- >::new(data_type)),
- Timestamp(Second, _) =>
Box::new(PrimitiveDistinctCountAccumulator::<
- TimestampSecondType,
- >::new(data_type)),
-
- Float16 =>
Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()),
- Float32 =>
Box::new(FloatDistinctCountAccumulator::<Float32Type>::new()),
- Float64 =>
Box::new(FloatDistinctCountAccumulator::<Float64Type>::new()),
-
- Utf8 =>
Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8)),
- LargeUtf8 => {
-
Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
- }
- Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
- OutputType::Binary,
- )),
- LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
- OutputType::Binary,
- )),
-
- // Use the generic accumulator based on `ScalarValue` for all
other types
- _ => Box::new(DistinctCountAccumulator {
- values: HashSet::default(),
- state_data_type: self.state_data_type.clone(),
- }),
- })
- }
-
- fn name(&self) -> &str {
- &self.name
- }
-}
-
-impl PartialEq<dyn Any> for DistinctCount {
- fn eq(&self, other: &dyn Any) -> bool {
- down_cast_any_ref(other)
- .downcast_ref::<Self>()
- .map(|x| {
- self.name == x.name
- && self.state_data_type == x.state_data_type
- && self.expr.eq(&x.expr)
- })
- .unwrap_or(false)
- }
-}
-
-/// General purpose distinct accumulator that works for any DataType by using
-/// [`ScalarValue`].
-///
-/// It stores intermediate results as a `ListArray`
-///
-/// Note that many types have specialized accumulators that are (much)
-/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
-/// [`BytesDistinctCountAccumulator`]
-#[derive(Debug)]
-struct DistinctCountAccumulator {
- values: HashSet<ScalarValue, RandomState>,
- state_data_type: DataType,
-}
-
-impl DistinctCountAccumulator {
- // calculating the size for fixed length values, taking first batch size *
- // number of batches This method is faster than .full_size(), however it is
- // not suitable for variable length values like strings or complex types
- fn fixed_size(&self) -> usize {
- std::mem::size_of_val(self)
- + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
- + self
- .values
- .iter()
- .next()
- .map(|vals| ScalarValue::size(vals) -
std::mem::size_of_val(vals))
- .unwrap_or(0)
- + std::mem::size_of::<DataType>()
- }
-
- // calculates the size as accurately as possible. Note that calling this
- // method is expensive
- fn full_size(&self) -> usize {
- std::mem::size_of_val(self)
- + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
- + self
- .values
- .iter()
- .map(|vals| ScalarValue::size(vals) -
std::mem::size_of_val(vals))
- .sum::<usize>()
- + std::mem::size_of::<DataType>()
- }
-}
-
-impl Accumulator for DistinctCountAccumulator {
- /// Returns the distinct values seen so far as (one element) ListArray.
- fn state(&mut self) -> Result<Vec<ScalarValue>> {
- let scalars = self.values.iter().cloned().collect::<Vec<_>>();
- let arr = ScalarValue::new_list(scalars.as_slice(),
&self.state_data_type);
- Ok(vec![ScalarValue::List(arr)])
- }
-
- fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- if values.is_empty() {
- return Ok(());
- }
-
- let arr = &values[0];
- if arr.data_type() == &DataType::Null {
- return Ok(());
- }
-
- (0..arr.len()).try_for_each(|index| {
- if !arr.is_null(index) {
- let scalar = ScalarValue::try_from_array(arr, index)?;
- self.values.insert(scalar);
- }
- Ok(())
- })
- }
-
- /// Merges multiple sets of distinct values into the current set.
- ///
- /// The input to this function is a `ListArray` with **multiple** rows,
- /// where each row contains the values from a partial aggregate's phase
(e.g.
- /// the result of calling `Self::state` on multiple accumulators).
- fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
- if states.is_empty() {
- return Ok(());
- }
- assert_eq!(states.len(), 1, "array_agg states must be singleton!");
- let array = &states[0];
- let list_array = array.as_list::<i32>();
- for inner_array in list_array.iter() {
- let Some(inner_array) = inner_array else {
- return internal_err!(
- "Intermediate results of COUNT DISTINCT should always be
non null"
- );
- };
- self.update_batch(&[inner_array])?;
- }
- Ok(())
- }
-
- fn evaluate(&mut self) -> Result<ScalarValue> {
- Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
- }
-
- fn size(&self) -> usize {
- match &self.state_data_type {
- DataType::Boolean | DataType::Null => self.fixed_size(),
- d if d.is_primitive() => self.fixed_size(),
- _ => self.full_size(),
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use arrow::array::{
- BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array,
- Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
- };
- use arrow_array::Decimal256Array;
- use arrow_buffer::i256;
-
- use datafusion_common::cast::{as_boolean_array, as_list_array,
as_primitive_array};
- use datafusion_common::internal_err;
- use datafusion_common::DataFusionError;
-
- use crate::expressions::NoOp;
-
- use super::*;
-
- macro_rules! state_to_vec_primitive {
- ($LIST:expr, $DATA_TYPE:ident) => {{
- let arr = ScalarValue::raw_data($LIST).unwrap();
- let list_arr = as_list_array(&arr).unwrap();
- let arr = list_arr.values();
- let arr = as_primitive_array::<$DATA_TYPE>(arr)?;
- arr.values().iter().cloned().collect::<Vec<_>>()
- }};
- }
-
- macro_rules! test_count_distinct_update_batch_numeric {
- ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
- let values: Vec<Option<$PRIM_TYPE>> = vec![
- Some(1),
- Some(1),
- None,
- Some(3),
- Some(2),
- None,
- Some(2),
- Some(3),
- Some(1),
- ];
-
- let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];
-
- let (states, result) = run_update_batch(&arrays)?;
-
- let mut state_vec = state_to_vec_primitive!(&states[0],
$DATA_TYPE);
- state_vec.sort();
-
- assert_eq!(states.len(), 1);
- assert_eq!(state_vec, vec![1, 2, 3]);
- assert_eq!(result, ScalarValue::Int64(Some(3)));
-
- Ok(())
- }};
- }
-
- fn state_to_vec_bool(sv: &ScalarValue) -> Result<Vec<bool>> {
- let arr = ScalarValue::raw_data(sv)?;
- let list_arr = as_list_array(&arr)?;
- let arr = list_arr.values();
- let bool_arr = as_boolean_array(arr)?;
- Ok(bool_arr.iter().flatten().collect())
- }
-
- fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>,
ScalarValue)> {
- let agg = DistinctCount::new(
- arrays[0].data_type().clone(),
- Arc::new(NoOp::new()),
- String::from("__col_name__"),
- );
-
- let mut accum = agg.create_accumulator()?;
- accum.update_batch(arrays)?;
-
- Ok((accum.state()?, accum.evaluate()?))
- }
-
- fn run_update(
- data_types: &[DataType],
- rows: &[Vec<ScalarValue>],
- ) -> Result<(Vec<ScalarValue>, ScalarValue)> {
- let agg = DistinctCount::new(
- data_types[0].clone(),
- Arc::new(NoOp::new()),
- String::from("__col_name__"),
- );
-
- let mut accum = agg.create_accumulator()?;
-
- let cols = (0..rows[0].len())
- .map(|i| {
- rows.iter()
- .map(|inner| inner[i].clone())
- .collect::<Vec<ScalarValue>>()
- })
- .collect::<Vec<_>>();
-
- let arrays: Vec<ArrayRef> = cols
- .iter()
- .map(|c| ScalarValue::iter_to_array(c.clone()))
- .collect::<Result<Vec<ArrayRef>>>()?;
-
- accum.update_batch(&arrays)?;
-
- Ok((accum.state()?, accum.evaluate()?))
- }
-
- // Used trait to create associated constant for f32 and f64
- trait SubNormal: 'static {
- const SUBNORMAL: Self;
- }
-
- impl SubNormal for f64 {
- const SUBNORMAL: Self = 1.0e-308_f64;
- }
-
- impl SubNormal for f32 {
- const SUBNORMAL: Self = 1.0e-38_f32;
- }
-
- macro_rules! test_count_distinct_update_batch_floating_point {
- ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
- let values: Vec<Option<$PRIM_TYPE>> = vec![
- Some(<$PRIM_TYPE>::INFINITY),
- Some(<$PRIM_TYPE>::NAN),
- Some(1.0),
- Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL),
- Some(1.0),
- Some(<$PRIM_TYPE>::INFINITY),
- None,
- Some(3.0),
- Some(-4.5),
- Some(2.0),
- None,
- Some(2.0),
- Some(3.0),
- Some(<$PRIM_TYPE>::NEG_INFINITY),
- Some(1.0),
- Some(<$PRIM_TYPE>::NAN),
- Some(<$PRIM_TYPE>::NEG_INFINITY),
- ];
-
- let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];
-
- let (states, result) = run_update_batch(&arrays)?;
-
- let mut state_vec = state_to_vec_primitive!(&states[0],
$DATA_TYPE);
-
- dbg!(&state_vec);
- state_vec.sort_by(|a, b| match (a, b) {
- (lhs, rhs) => lhs.total_cmp(rhs),
- });
-
- let nan_idx = state_vec.len() - 1;
- assert_eq!(states.len(), 1);
- assert_eq!(
- &state_vec[..nan_idx],
- vec![
- <$PRIM_TYPE>::NEG_INFINITY,
- -4.5,
- <$PRIM_TYPE as SubNormal>::SUBNORMAL,
- 1.0,
- 2.0,
- 3.0,
- <$PRIM_TYPE>::INFINITY
- ]
- );
- assert!(state_vec[nan_idx].is_nan());
- assert_eq!(result, ScalarValue::Int64(Some(8)));
-
- Ok(())
- }};
- }
-
- macro_rules! test_count_distinct_update_batch_bigint {
- ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
- let values: Vec<Option<$PRIM_TYPE>> = vec![
- Some(i256::from(1)),
- Some(i256::from(1)),
- None,
- Some(i256::from(3)),
- Some(i256::from(2)),
- None,
- Some(i256::from(2)),
- Some(i256::from(3)),
- Some(i256::from(1)),
- ];
-
- let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];
-
- let (states, result) = run_update_batch(&arrays)?;
-
- let mut state_vec = state_to_vec_primitive!(&states[0],
$DATA_TYPE);
- state_vec.sort();
-
- assert_eq!(states.len(), 1);
- assert_eq!(state_vec, vec![i256::from(1), i256::from(2),
i256::from(3)]);
- assert_eq!(result, ScalarValue::Int64(Some(3)));
-
- Ok(())
- }};
- }
-
- #[test]
- fn count_distinct_update_batch_i8() -> Result<()> {
- test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8)
- }
-
- #[test]
- fn count_distinct_update_batch_i16() -> Result<()> {
- test_count_distinct_update_batch_numeric!(Int16Array, Int16Type, i16)
- }
-
- #[test]
- fn count_distinct_update_batch_i32() -> Result<()> {
- test_count_distinct_update_batch_numeric!(Int32Array, Int32Type, i32)
- }
-
- #[test]
- fn count_distinct_update_batch_i64() -> Result<()> {
- test_count_distinct_update_batch_numeric!(Int64Array, Int64Type, i64)
- }
-
- #[test]
- fn count_distinct_update_batch_u8() -> Result<()> {
- test_count_distinct_update_batch_numeric!(UInt8Array, UInt8Type, u8)
- }
-
- #[test]
- fn count_distinct_update_batch_u16() -> Result<()> {
- test_count_distinct_update_batch_numeric!(UInt16Array, UInt16Type, u16)
- }
-
- #[test]
- fn count_distinct_update_batch_u32() -> Result<()> {
- test_count_distinct_update_batch_numeric!(UInt32Array, UInt32Type, u32)
- }
-
- #[test]
- fn count_distinct_update_batch_u64() -> Result<()> {
- test_count_distinct_update_batch_numeric!(UInt64Array, UInt64Type, u64)
- }
-
- #[test]
- fn count_distinct_update_batch_f32() -> Result<()> {
- test_count_distinct_update_batch_floating_point!(Float32Array,
Float32Type, f32)
- }
-
- #[test]
- fn count_distinct_update_batch_f64() -> Result<()> {
- test_count_distinct_update_batch_floating_point!(Float64Array,
Float64Type, f64)
- }
-
- #[test]
- fn count_distinct_update_batch_i256() -> Result<()> {
- test_count_distinct_update_batch_bigint!(Decimal256Array,
Decimal256Type, i256)
- }
-
- #[test]
- fn count_distinct_update_batch_boolean() -> Result<()> {
- let get_count = |data: BooleanArray| -> Result<(Vec<bool>, i64)> {
- let arrays = vec![Arc::new(data) as ArrayRef];
- let (states, result) = run_update_batch(&arrays)?;
- let mut state_vec = state_to_vec_bool(&states[0])?;
- state_vec.sort();
-
- let count = match result {
- ScalarValue::Int64(c) => c.ok_or_else(|| {
- DataFusionError::Internal("Found None count".to_string())
- }),
- scalar => {
- internal_err!("Found non int64 scalar value from count:
{scalar}")
- }
- }?;
- Ok((state_vec, count))
- };
-
- let zero_count_values = BooleanArray::from(Vec::<bool>::new());
-
- let one_count_values = BooleanArray::from(vec![false, false]);
- let one_count_values_with_null =
- BooleanArray::from(vec![Some(true), Some(true), None, None]);
-
- let two_count_values = BooleanArray::from(vec![true, false, true,
false, true]);
- let two_count_values_with_null = BooleanArray::from(vec![
- Some(true),
- Some(false),
- None,
- None,
- Some(true),
- Some(false),
- ]);
-
- assert_eq!(get_count(zero_count_values)?, (Vec::<bool>::new(), 0));
- assert_eq!(get_count(one_count_values)?, (vec![false], 1));
- assert_eq!(get_count(one_count_values_with_null)?, (vec![true], 1));
- assert_eq!(get_count(two_count_values)?, (vec![false, true], 2));
- assert_eq!(
- get_count(two_count_values_with_null)?,
- (vec![false, true], 2)
- );
- Ok(())
- }
-
- #[test]
- fn count_distinct_update_batch_all_nulls() -> Result<()> {
- let arrays = vec![Arc::new(Int32Array::from(
- vec![None, None, None, None] as Vec<Option<i32>>
- )) as ArrayRef];
-
- let (states, result) = run_update_batch(&arrays)?;
- let state_vec = state_to_vec_primitive!(&states[0], Int32Type);
- assert_eq!(states.len(), 1);
- assert!(state_vec.is_empty());
- assert_eq!(result, ScalarValue::Int64(Some(0)));
-
- Ok(())
- }
-
- #[test]
- fn count_distinct_update_batch_empty() -> Result<()> {
- let arrays = vec![Arc::new(Int32Array::from(vec![0_i32; 0])) as
ArrayRef];
-
- let (states, result) = run_update_batch(&arrays)?;
- let state_vec = state_to_vec_primitive!(&states[0], Int32Type);
- assert_eq!(states.len(), 1);
- assert!(state_vec.is_empty());
- assert_eq!(result, ScalarValue::Int64(Some(0)));
-
- Ok(())
- }
-
- #[test]
- fn count_distinct_update() -> Result<()> {
- let (states, result) = run_update(
- &[DataType::Int32],
- &[
- vec![ScalarValue::Int32(Some(-1))],
- vec![ScalarValue::Int32(Some(5))],
- vec![ScalarValue::Int32(Some(-1))],
- vec![ScalarValue::Int32(Some(5))],
- vec![ScalarValue::Int32(Some(-1))],
- vec![ScalarValue::Int32(Some(-1))],
- vec![ScalarValue::Int32(Some(2))],
- ],
- )?;
- assert_eq!(states.len(), 1);
- assert_eq!(result, ScalarValue::Int64(Some(3)));
-
- let (states, result) = run_update(
- &[DataType::UInt64],
- &[
- vec![ScalarValue::UInt64(Some(1))],
- vec![ScalarValue::UInt64(Some(5))],
- vec![ScalarValue::UInt64(Some(1))],
- vec![ScalarValue::UInt64(Some(5))],
- vec![ScalarValue::UInt64(Some(1))],
- vec![ScalarValue::UInt64(Some(1))],
- vec![ScalarValue::UInt64(Some(2))],
- ],
- )?;
- assert_eq!(states.len(), 1);
- assert_eq!(result, ScalarValue::Int64(Some(3)));
- Ok(())
- }
-
- #[test]
- fn count_distinct_update_with_nulls() -> Result<()> {
- let (states, result) = run_update(
- &[DataType::Int32],
- &[
- // None of these updates contains a None, so these are
accumulated.
- vec![ScalarValue::Int32(Some(-1))],
- vec![ScalarValue::Int32(Some(-1))],
- vec![ScalarValue::Int32(Some(-2))],
- // Each of these updates contains at least one None, so these
- // won't be accumulated.
- vec![ScalarValue::Int32(Some(-1))],
- vec![ScalarValue::Int32(None)],
- vec![ScalarValue::Int32(None)],
- ],
- )?;
- assert_eq!(states.len(), 1);
- assert_eq!(result, ScalarValue::Int64(Some(2)));
-
- let (states, result) = run_update(
- &[DataType::UInt64],
- &[
- // None of these updates contains a None, so these are
accumulated.
- vec![ScalarValue::UInt64(Some(1))],
- vec![ScalarValue::UInt64(Some(1))],
- vec![ScalarValue::UInt64(Some(2))],
- // Each of these updates contains at least one None, so these
- // won't be accumulated.
- vec![ScalarValue::UInt64(Some(1))],
- vec![ScalarValue::UInt64(None)],
- vec![ScalarValue::UInt64(None)],
- ],
- )?;
- assert_eq!(states.len(), 1);
- assert_eq!(result, ScalarValue::Int64(Some(2)));
- Ok(())
- }
-}
diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs
b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs
index 65227b727b..a6946e739c 100644
--- a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs
@@ -20,7 +20,7 @@ pub use adapter::GroupsAccumulatorAdapter;
// Backward compatibility
pub(crate) mod accumulate {
- pub use
datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::{accumulate_indices,
NullState};
+ pub use
datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState;
}
pub use
datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState;
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs
b/datafusion/physical-expr/src/aggregate/mod.rs
index 7a6c5f9d0e..01105c8559 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -26,8 +26,6 @@ pub(crate) mod average;
pub(crate) mod bit_and_or_xor;
pub(crate) mod bool_and_or;
pub(crate) mod correlation;
-pub(crate) mod count;
-pub(crate) mod count_distinct;
pub(crate) mod covariance;
pub(crate) mod grouping;
pub(crate) mod nth_value;
diff --git a/datafusion/physical-expr/src/expressions/mod.rs
b/datafusion/physical-expr/src/expressions/mod.rs
index a96d021730..123ada6d7c 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -47,8 +47,6 @@ pub use crate::aggregate::bit_and_or_xor::{BitAnd, BitOr,
BitXor, DistinctBitXor
pub use crate::aggregate::bool_and_or::{BoolAnd, BoolOr};
pub use crate::aggregate::build_in::create_aggregate_expr;
pub use crate::aggregate::correlation::Correlation;
-pub use crate::aggregate::count::Count;
-pub use crate::aggregate::count_distinct::DistinctCount;
pub use crate::aggregate::grouping::Grouping;
pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator};
pub use crate::aggregate::nth_value::NthValueAgg;
diff --git a/datafusion/physical-expr/src/lib.rs
b/datafusion/physical-expr/src/lib.rs
index 72f5f2d50c..b764e81a95 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -17,7 +17,9 @@
pub mod aggregate;
pub mod analysis;
-pub mod binary_map;
+pub mod binary_map {
+ pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType};
+}
pub mod equivalence;
pub mod expressions;
pub mod functions;
diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs
b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs
index d073c8995a..f789af8b8a 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs
@@ -18,7 +18,7 @@
use crate::aggregates::group_values::GroupValues;
use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch};
use datafusion_expr::EmitTo;
-use datafusion_physical_expr::binary_map::{ArrowBytesMap, OutputType};
+use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType};
/// A [`GroupValues`] storing single column of
Utf8/LargeUtf8/Binary/LargeBinary values
///
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs
b/datafusion/physical-plan/src/aggregates/mod.rs
index 79abbdb52c..b6fc70be7c 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -1194,12 +1194,14 @@ mod tests {
use datafusion_execution::memory_pool::FairSpillPool;
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_expr::expr::Sort;
+ use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::median::median_udaf;
use datafusion_physical_expr::expressions::{
- lit, Count, FirstValue, LastValue, OrderSensitiveArrayAgg,
+ lit, FirstValue, LastValue, OrderSensitiveArrayAgg,
};
use datafusion_physical_expr::PhysicalSortExpr;
+ use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
use futures::{FutureExt, Stream};
// Generate a schema which consists of 5 columns (a, b, c, d, e)
@@ -1334,11 +1336,16 @@ mod tests {
],
};
- let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Count::new(
- lit(1i8),
- "COUNT(1)".to_string(),
- DataType::Int64,
- ))];
+ let aggregates = vec![create_aggregate_expr(
+ &count_udaf(),
+ &[lit(1i8)],
+ &[],
+ &[],
+ &input_schema,
+ "COUNT(1)",
+ false,
+ false,
+ )?];
let task_ctx = if spill {
new_spill_ctx(4, 1000)
diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
index 48f1bee59b..56d780e513 100644
--- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
@@ -1194,9 +1194,9 @@ mod tests {
RecordBatchStream, SendableRecordBatchStream, TaskContext,
};
use datafusion_expr::{
- AggregateFunction, WindowFrame, WindowFrameBound, WindowFrameUnits,
- WindowFunctionDefinition,
+ WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
+ use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::expressions::{col, Column, NthValue};
use datafusion_physical_expr::window::{
BuiltInWindowExpr, BuiltInWindowFunctionExpr,
@@ -1298,8 +1298,7 @@ mod tests {
order_by: &str,
) -> Result<Arc<dyn ExecutionPlan>> {
let schema = input.schema();
- let window_fn =
- WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count);
+ let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf());
let col_expr =
Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc<dyn
PhysicalExpr>;
let args = vec![col_expr];
diff --git a/datafusion/physical-plan/src/windows/mod.rs
b/datafusion/physical-plan/src/windows/mod.rs
index 9b392d941e..63ce473fc5 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -597,7 +597,6 @@ pub fn get_window_mode(
#[cfg(test)]
mod tests {
use super::*;
- use crate::aggregates::AggregateFunction;
use crate::collect;
use crate::expressions::col;
use crate::streaming::StreamingTableExec;
@@ -607,6 +606,7 @@ mod tests {
use arrow::compute::SortOptions;
use datafusion_execution::TaskContext;
+ use datafusion_functions_aggregate::count::count_udaf;
use futures::FutureExt;
use InputOrderMode::{Linear, PartiallySorted, Sorted};
@@ -749,7 +749,7 @@ mod tests {
let refs = blocking_exec.refs();
let window_agg_exec = Arc::new(WindowAggExec::try_new(
vec![create_window_expr(
- &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
+ &WindowFunctionDefinition::AggregateUDF(count_udaf()),
"count".to_owned(),
&[col("a", &schema)?],
&[],
diff --git a/datafusion/proto-common/Cargo.toml
b/datafusion/proto-common/Cargo.toml
index 97568fb5f6..66ce7cbd83 100644
--- a/datafusion/proto-common/Cargo.toml
+++ b/datafusion/proto-common/Cargo.toml
@@ -26,7 +26,7 @@ homepage = { workspace = true }
repository = { workspace = true }
license = { workspace = true }
authors = { workspace = true }
-rust-version = "1.73"
+rust-version = "1.75"
# Exclude proto files so crates.io consumers don't need protoc
exclude = ["*.proto"]
diff --git a/datafusion/proto-common/gen/Cargo.toml
b/datafusion/proto-common/gen/Cargo.toml
index 49884c48b3..9f8f03de6d 100644
--- a/datafusion/proto-common/gen/Cargo.toml
+++ b/datafusion/proto-common/gen/Cargo.toml
@@ -20,7 +20,7 @@ name = "gen-common"
description = "Code generation for proto"
version = "0.1.0"
edition = { workspace = true }
-rust-version = "1.73"
+rust-version = "1.75"
authors = { workspace = true }
homepage = { workspace = true }
repository = { workspace = true }
diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml
index 358ba7e3eb..b1897aa58e 100644
--- a/datafusion/proto/Cargo.toml
+++ b/datafusion/proto/Cargo.toml
@@ -27,7 +27,7 @@ repository = { workspace = true }
license = { workspace = true }
authors = { workspace = true }
# Specify MSRV here as `cargo msrv` doesn't support workspace version
-rust-version = "1.73"
+rust-version = "1.75"
# Exclude proto files so crates.io consumers don't need protoc
exclude = ["*.proto"]
diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml
index b6993f6c04..eabaf7ba8e 100644
--- a/datafusion/proto/gen/Cargo.toml
+++ b/datafusion/proto/gen/Cargo.toml
@@ -20,7 +20,7 @@ name = "gen"
description = "Code generation for proto"
version = "0.1.0"
edition = { workspace = true }
-rust-version = "1.73"
+rust-version = "1.75"
authors = { workspace = true }
homepage = { workspace = true }
repository = { workspace = true }
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index b401ff8810..2bb3ec793d 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -520,6 +520,7 @@ message AggregateExprNode {
message AggregateUDFExprNode {
string fun_name = 1;
repeated LogicalExprNode args = 2;
+ bool distinct = 5;
LogicalExprNode filter = 3;
repeated LogicalExprNode order_by = 4;
}
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index d6632c77d8..59b7861a6e 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -886,6 +886,9 @@ impl serde::Serialize for AggregateUdfExprNode {
if !self.args.is_empty() {
len += 1;
}
+ if self.distinct {
+ len += 1;
+ }
if self.filter.is_some() {
len += 1;
}
@@ -899,6 +902,9 @@ impl serde::Serialize for AggregateUdfExprNode {
if !self.args.is_empty() {
struct_ser.serialize_field("args", &self.args)?;
}
+ if self.distinct {
+ struct_ser.serialize_field("distinct", &self.distinct)?;
+ }
if let Some(v) = self.filter.as_ref() {
struct_ser.serialize_field("filter", v)?;
}
@@ -918,6 +924,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
"fun_name",
"funName",
"args",
+ "distinct",
"filter",
"order_by",
"orderBy",
@@ -927,6 +934,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
enum GeneratedField {
FunName,
Args,
+ Distinct,
Filter,
OrderBy,
}
@@ -952,6 +960,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
match value {
"funName" | "fun_name" =>
Ok(GeneratedField::FunName),
"args" => Ok(GeneratedField::Args),
+ "distinct" => Ok(GeneratedField::Distinct),
"filter" => Ok(GeneratedField::Filter),
"orderBy" | "order_by" =>
Ok(GeneratedField::OrderBy),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
@@ -975,6 +984,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
{
let mut fun_name__ = None;
let mut args__ = None;
+ let mut distinct__ = None;
let mut filter__ = None;
let mut order_by__ = None;
while let Some(k) = map_.next_key()? {
@@ -991,6 +1001,12 @@ impl<'de> serde::Deserialize<'de> for
AggregateUdfExprNode {
}
args__ = Some(map_.next_value()?);
}
+ GeneratedField::Distinct => {
+ if distinct__.is_some() {
+ return Err(serde::de::Error::duplicate_field("distinct"));
+ }
+ distinct__ = Some(map_.next_value()?);
+ }
GeneratedField::Filter => {
if filter__.is_some() {
return
Err(serde::de::Error::duplicate_field("filter"));
@@ -1008,6 +1024,7 @@ impl<'de> serde::Deserialize<'de> for
AggregateUdfExprNode {
Ok(AggregateUdfExprNode {
fun_name: fun_name__.unwrap_or_default(),
args: args__.unwrap_or_default(),
+ distinct: distinct__.unwrap_or_default(),
filter: filter__,
order_by: order_by__.unwrap_or_default(),
})
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 0aca5ef1ff..0861c287fc 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -767,6 +767,8 @@ pub struct AggregateUdfExprNode {
pub fun_name: ::prost::alloc::string::String,
#[prost(message, repeated, tag = "2")]
pub args: ::prost::alloc::vec::Vec<LogicalExprNode>,
+ #[prost(bool, tag = "5")]
+ pub distinct: bool,
#[prost(message, optional, boxed, tag = "3")]
pub filter:
::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
#[prost(message, repeated, tag = "4")]
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 3ad5973380..2ad40d883f 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -642,7 +642,7 @@ pub fn parse_expr(
Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
agg_fn,
parse_exprs(&pb.args, registry, codec)?,
- false,
+ pb.distinct,
parse_optional_expr(pb.filter.as_deref(), registry,
codec)?.map(Box::new),
parse_vec_expr(&pb.order_by, registry, codec)?,
None,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index d42470f198..6a275ed7a1 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -456,6 +456,7 @@ pub fn serialize_expr(
protobuf::AggregateUdfExprNode {
fun_name: fun.name().to_string(),
args: serialize_exprs(args, codec)?,
+ distinct: *distinct,
filter: match filter {
Some(e) =>
Some(Box::new(serialize_expr(e.as_ref(), codec)?)),
None => None,
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs
b/datafusion/proto/src/physical_plan/to_proto.rs
index 5258bdd11d..e25447b023 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -25,10 +25,10 @@ use datafusion::physical_expr::{PhysicalSortExpr,
ScalarFunctionExpr};
use datafusion::physical_plan::expressions::{
ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg,
BinaryExpr,
BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column,
Correlation,
- Count, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, Grouping,
- InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr,
NotExpr,
- NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType,
Regr, RegrType,
- RowNumber, StringAgg, TryCastExpr, WindowShift,
+ CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr,
IsNotNullExpr,
+ IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue,
NthValueAgg, Ntile,
+ OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber,
StringAgg,
+ TryCastExpr, WindowShift,
};
use datafusion::physical_plan::udaf::AggregateFunctionExpr;
use datafusion::physical_plan::windows::{BuiltInWindowExpr,
PlainAggregateWindowExpr};
@@ -240,12 +240,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) ->
Result<AggrFn> {
let aggr_expr = expr.as_any();
let mut distinct = false;
- let inner = if aggr_expr.downcast_ref::<Count>().is_some() {
- protobuf::AggregateFunction::Count
- } else if aggr_expr.downcast_ref::<DistinctCount>().is_some() {
- distinct = true;
- protobuf::AggregateFunction::Count
- } else if aggr_expr.downcast_ref::<Grouping>().is_some() {
+ let inner = if aggr_expr.downcast_ref::<Grouping>().is_some() {
protobuf::AggregateFunction::Grouping
} else if aggr_expr.downcast_ref::<BitAnd>().is_some() {
protobuf::AggregateFunction::BitAnd
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 699697dd2f..d9736da69d 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -649,6 +649,8 @@ async fn roundtrip_expr_api() -> Result<()> {
lit(1),
),
array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2),
lit(4)),
+ count(lit(1)),
+ count_distinct(lit(1)),
first_value(lit(1), None),
first_value(lit(1), Some(vec![lit(2).sort(true, true)])),
covar_samp(lit(1.5), lit(2.2)),
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 9cf686dbd3..e517482f1d 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -38,7 +38,7 @@ use datafusion::datasource::physical_plan::{
};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility};
-use datafusion::physical_expr::expressions::{Count, Max, NthValueAgg};
+use datafusion::physical_expr::expressions::{Max, NthValueAgg};
use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
use datafusion::physical_plan::aggregates::{
@@ -47,8 +47,8 @@ use datafusion::physical_plan::aggregates::{
use datafusion::physical_plan::analyze::AnalyzeExec;
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::physical_plan::expressions::{
- binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column,
DistinctCount,
- NotExpr, NthValue, PhysicalSortExpr, StringAgg,
+ binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, NotExpr,
NthValue,
+ PhysicalSortExpr, StringAgg,
};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::insert::DataSinkExec;
@@ -806,7 +806,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
let aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Final,
PhysicalGroupBy::new(vec![], vec![], vec![]),
- vec![Arc::new(Count::new(udf_expr, "count", DataType::Int64))],
+ vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))],
vec![None],
window,
schema.clone(),
@@ -818,31 +818,6 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
Ok(())
}
-#[test]
-fn roundtrip_distinct_count() -> Result<()> {
- let field_a = Field::new("a", DataType::Int64, false);
- let field_b = Field::new("b", DataType::Int64, false);
- let schema = Arc::new(Schema::new(vec![field_a, field_b]));
-
- let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(DistinctCount::new(
- DataType::Int64,
- col("b", &schema)?,
- "COUNT(DISTINCT b)".to_string(),
- ))];
-
- let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
- vec![(col("a", &schema)?, "unused".to_string())];
-
- roundtrip_test(Arc::new(AggregateExec::try_new(
- AggregateMode::Final,
- PhysicalGroupBy::new_single(groups),
- aggregates.clone(),
- vec![None],
- Arc::new(EmptyExec::new(schema.clone())),
- schema,
- )?))
-}
-
#[test]
fn roundtrip_like() -> Result<()> {
let schema = Schema::new(vec![
diff --git a/datafusion/sqllogictest/test_files/errors.slt
b/datafusion/sqllogictest/test_files/errors.slt
index e930af107f..c7b9808c24 100644
--- a/datafusion/sqllogictest/test_files/errors.slt
+++ b/datafusion/sqllogictest/test_files/errors.slt
@@ -46,7 +46,7 @@ statement error DataFusion error: Arrow error: Cast error:
Cannot cast string 'c
SELECT CAST(c1 AS INT) FROM aggregate_test_100
# aggregation_with_bad_arguments
-statement error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\)
+query error
SELECT COUNT(DISTINCT) FROM aggregate_test_100
# query_cte_incorrect
@@ -104,7 +104,7 @@ SELECT power(1, 2, 3);
#
# AggregateFunction with wrong number of arguments
-statement error DataFusion error: Error during planning: No function matches the given name and argument types 'COUNT\(\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tCOUNT\(Any, \.\., Any\)
+query error
select count();
# AggregateFunction with wrong number of arguments
diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml
index ee96ffa670..d934dba4cf 100644
--- a/datafusion/substrait/Cargo.toml
+++ b/datafusion/substrait/Cargo.toml
@@ -26,7 +26,7 @@ repository = { workspace = true }
license = { workspace = true }
authors = { workspace = true }
# Specify MSRV here as `cargo msrv` doesn't support workspace version
-rust-version = "1.73"
+rust-version = "1.75"
[lints]
workspace = true
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index 3f9a895d95..93f197885c 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -982,18 +982,16 @@ pub async fn from_substrait_agg_func(
let function_name = substrait_fun_name((**function_name).as_str());
// try udaf first, then built-in aggr fn.
if let Ok(fun) = ctx.udaf(function_name) {
+ // deal with situation that count(*) got no arguments
+ if fun.name() == "COUNT" && args.is_empty() {
+ args.push(Expr::Literal(ScalarValue::Int64(Some(1))));
+ }
+
Ok(Arc::new(Expr::AggregateFunction(
expr::AggregateFunction::new_udf(fun, args, distinct, filter,
order_by, None),
)))
} else if let Ok(fun) =
aggregate_function::AggregateFunction::from_str(function_name)
{
- match &fun {
- // deal with situation that count(*) got no arguments
- aggregate_function::AggregateFunction::Count if args.is_empty() => {
- args.push(Expr::Literal(ScalarValue::Int64(Some(1))));
- }
- _ => {}
- }
Ok(Arc::new(Expr::AggregateFunction(
expr::AggregateFunction::new(fun, args, distinct, filter,
order_by, None),
)))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]