This is an automated email from the ASF dual-hosted git repository.

yjshen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new dafe99733e feat: Support SQL filter clause for aggregate expressions, add SQL dialect support (#5868)
dafe99733e is described below

commit dafe99733e0f97bfb5ef750f02d02abcb641682d
Author: Yijie Shen <[email protected]>
AuthorDate: Wed Apr 12 02:09:23 2023 +0800

    feat: Support SQL filter clause for aggregate expressions, add SQL dialect support (#5868)
---
 datafusion/common/src/config.rs                    |   4 +
 datafusion/core/src/execution/context.rs           |  35 +++-
 .../src/physical_optimizer/aggregate_statistics.rs |  13 ++
 .../src/physical_optimizer/dist_enforcement.rs     |   8 +
 .../core/src/physical_optimizer/repartition.rs     |   2 +
 .../src/physical_optimizer/sort_enforcement.rs     |   1 +
 .../core/src/physical_plan/aggregates/mod.rs       |  33 ++++
 .../src/physical_plan/aggregates/no_grouping.rs    |  35 +++-
 .../core/src/physical_plan/aggregates/row_hash.rs  | 211 +++++++++++++--------
 datafusion/core/src/physical_plan/filter.rs        |   2 +-
 datafusion/core/src/physical_plan/planner.rs       |  57 ++++--
 .../tests/sqllogictests/test_files/aggregate.slt   |  80 ++++++++
 .../test_files/information_schema.slt              |   1 +
 datafusion/expr/src/tree_node/expr.rs              |   4 +-
 datafusion/optimizer/src/push_down_projection.rs   |   2 +-
 datafusion/proto/proto/datafusion.proto            |   5 +
 datafusion/proto/src/generated/pbjson.rs           | 109 +++++++++++
 datafusion/proto/src/generated/prost.rs            |   8 +
 datafusion/proto/src/physical_plan/mod.rs          |  22 +++
 datafusion/proto/src/physical_plan/to_proto.rs     |  13 ++
 datafusion/sql/tests/integration_test.rs           |   4 +-
 docs/source/user-guide/configs.md                  |   1 +
 22 files changed, 548 insertions(+), 102 deletions(-)

diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index 55cdc36d20..5973bf262e 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -187,6 +187,10 @@ config_namespace! {
         /// When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted)
         pub enable_ident_normalization: bool, default = true
 
+        /// Configure the SQL dialect used by DataFusion's parser; supported values include: Generic,
+        /// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi.
+        pub dialect: String, default = "generic".to_string()
+
     }
 }
 
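For reference, the new option can be changed per-session with `SET datafusion.sql_parser.dialect = '...'` (as the sqllogictest changes below do) or when building the context. A minimal sketch, assuming the usual public `SessionConfig`/`SessionContext` APIs rather than anything added by this commit:

    use datafusion::prelude::{SessionConfig, SessionContext};

    // Route all SQL parsing in this context through the PostgreSQL dialect.
    let config = SessionConfig::new()
        .set_str("datafusion.sql_parser.dialect", "PostgreSQL");
    let ctx = SessionContext::with_config(config);
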
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index c3adb4cc74..5114adec72 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -91,6 +91,11 @@ use datafusion_sql::{
     planner::{ContextProvider, SqlToRel},
 };
 use parquet::file::properties::WriterProperties;
+use sqlparser::dialect::{
+    AnsiDialect, BigQueryDialect, ClickHouseDialect, Dialect, GenericDialect,
+    HiveDialect, MsSqlDialect, MySqlDialect, PostgreSqlDialect, RedshiftSqlDialect,
+    SQLiteDialect, SnowflakeDialect,
+};
 use url::Url;
 
 use crate::catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA};
@@ -1500,8 +1505,10 @@ impl SessionState {
     pub fn sql_to_statement(
         &self,
         sql: &str,
+        dialect: &str,
     ) -> Result<datafusion_sql::parser::Statement> {
-        let mut statements = DFParser::parse_sql(sql)?;
+        let dialect = create_dialect_from_str(dialect)?;
+        let mut statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?;
         if statements.len() > 1 {
             return Err(DataFusionError::NotImplemented(
                 "The context currently only supports a single SQL 
statement".to_string(),
@@ -1629,7 +1636,8 @@ impl SessionState {
     ///
     /// See [`SessionContext::sql`] for a higher-level interface that also handles DDL
     pub async fn create_logical_plan(&self, sql: &str) -> Result<LogicalPlan> {
-        let statement = self.sql_to_statement(sql)?;
+        let dialect = self.config.options().sql_parser.dialect.as_str();
+        let statement = self.sql_to_statement(sql, dialect)?;
         let plan = self.statement_to_plan(statement).await?;
         Ok(plan)
     }
@@ -1838,6 +1846,29 @@ impl From<&SessionState> for TaskContext {
     }
 }
 
+// TODO: remove when https://github.com/sqlparser-rs/sqlparser-rs/pull/848 is released
+fn create_dialect_from_str(dialect_name: &str) -> Result<Box<dyn Dialect>> {
+    match dialect_name.to_lowercase().as_str() {
+        "generic" => Ok(Box::new(GenericDialect)),
+        "mysql" => Ok(Box::new(MySqlDialect {})),
+        "postgresql" | "postgres" => Ok(Box::new(PostgreSqlDialect {})),
+        "hive" => Ok(Box::new(HiveDialect {})),
+        "sqlite" => Ok(Box::new(SQLiteDialect {})),
+        "snowflake" => Ok(Box::new(SnowflakeDialect)),
+        "redshift" => Ok(Box::new(RedshiftSqlDialect {})),
+        "mssql" => Ok(Box::new(MsSqlDialect {})),
+        "clickhouse" => Ok(Box::new(ClickHouseDialect {})),
+        "bigquery" => Ok(Box::new(BigQueryDialect)),
+        "ansi" => Ok(Box::new(AnsiDialect {})),
+        _ => {
+            Err(DataFusionError::Internal(format!(
+                "Unsupported SQL dialect: {}. Available dialects: Generic, 
MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, 
BigQuery, Ansi.",
+                dialect_name
+            )))
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
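Dialect names are matched case-insensitively by `create_dialect_from_str`, so "PostgreSQL", "postgresql", and "postgres" all select `PostgreSqlDialect`. A hedged sketch of the new `sql_to_statement` signature in use (the `ctx` binding is illustrative, not part of this commit):

    // Parse and plan a single statement under an explicit dialect.
    let state = ctx.state();
    let statement = state.sql_to_statement("SELECT 1", "postgres")?;
    let plan = state.statement_to_plan(statement).await?;
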
diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index 59806a0a2f..b88f73d8c2 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -123,6 +123,7 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>>
                 {
                     if partial_agg_exec.mode() == &AggregateMode::Partial
                         && partial_agg_exec.group_expr().is_empty()
+                        && partial_agg_exec.filter_expr().iter().all(|e| e.is_none())
                     {
                         let stats = partial_agg_exec.input().statistics();
                         if stats.is_exact {
@@ -410,6 +411,7 @@ mod tests {
             AggregateMode::Partial,
             PhysicalGroupBy::default(),
             vec![agg.count_expr()],
+            vec![None],
             source,
             Arc::clone(&schema),
         )?;
@@ -418,6 +420,7 @@ mod tests {
             AggregateMode::Final,
             PhysicalGroupBy::default(),
             vec![agg.count_expr()],
+            vec![None],
             Arc::new(partial_agg),
             Arc::clone(&schema),
         )?;
@@ -438,6 +441,7 @@ mod tests {
             AggregateMode::Partial,
             PhysicalGroupBy::default(),
             vec![agg.count_expr()],
+            vec![None],
             source,
             Arc::clone(&schema),
         )?;
@@ -446,6 +450,7 @@ mod tests {
             AggregateMode::Final,
             PhysicalGroupBy::default(),
             vec![agg.count_expr()],
+            vec![None],
             Arc::new(partial_agg),
             Arc::clone(&schema),
         )?;
@@ -465,6 +470,7 @@ mod tests {
             AggregateMode::Partial,
             PhysicalGroupBy::default(),
             vec![agg.count_expr()],
+            vec![None],
             source,
             Arc::clone(&schema),
         )?;
@@ -476,6 +482,7 @@ mod tests {
             AggregateMode::Final,
             PhysicalGroupBy::default(),
             vec![agg.count_expr()],
+            vec![None],
             Arc::new(coalesce),
             Arc::clone(&schema),
         )?;
@@ -495,6 +502,7 @@ mod tests {
             AggregateMode::Partial,
             PhysicalGroupBy::default(),
             vec![agg.count_expr()],
+            vec![None],
             source,
             Arc::clone(&schema),
         )?;
@@ -506,6 +514,7 @@ mod tests {
             AggregateMode::Final,
             PhysicalGroupBy::default(),
             vec![agg.count_expr()],
+            vec![None],
             Arc::new(coalesce),
             Arc::clone(&schema),
         )?;
@@ -536,6 +545,7 @@ mod tests {
             AggregateMode::Partial,
             PhysicalGroupBy::default(),
             vec![agg.count_expr()],
+            vec![None],
             filter,
             Arc::clone(&schema),
         )?;
@@ -544,6 +554,7 @@ mod tests {
             AggregateMode::Final,
             PhysicalGroupBy::default(),
             vec![agg.count_expr()],
+            vec![None],
             Arc::new(partial_agg),
             Arc::clone(&schema),
         )?;
@@ -579,6 +590,7 @@ mod tests {
             AggregateMode::Partial,
             PhysicalGroupBy::default(),
             vec![agg.count_expr()],
+            vec![None],
             filter,
             Arc::clone(&schema),
         )?;
@@ -587,6 +599,7 @@ mod tests {
             AggregateMode::Final,
             PhysicalGroupBy::default(),
             vec![agg.count_expr()],
+            vec![None],
             Arc::new(partial_agg),
             Arc::clone(&schema),
         )?;
diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs
index d3e99945e9..affe432830 100644
--- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs
@@ -252,6 +252,7 @@ fn adjust_input_keys_ordering(
         mode,
         group_by,
         aggr_expr,
+        filter_expr,
         input,
         input_schema,
         ..
@@ -264,6 +265,7 @@ fn adjust_input_keys_ordering(
                     &parent_required,
                     group_by,
                     aggr_expr,
+                    filter_expr,
                     input.clone(),
                     input_schema,
                 )?),
@@ -369,6 +371,7 @@ fn reorder_aggregate_keys(
     parent_required: &[Arc<dyn PhysicalExpr>],
     group_by: &PhysicalGroupBy,
     aggr_expr: &[Arc<dyn AggregateExpr>],
+    filter_expr: &[Option<Arc<dyn PhysicalExpr>>],
     agg_input: Arc<dyn ExecutionPlan>,
     input_schema: &SchemaRef,
 ) -> Result<PlanWithKeyRequirements> {
@@ -398,6 +401,7 @@ fn reorder_aggregate_keys(
                     mode,
                     group_by,
                     aggr_expr,
+                    filter_expr,
                     input,
                     input_schema,
                     ..
@@ -416,6 +420,7 @@ fn reorder_aggregate_keys(
                             AggregateMode::Partial,
                             new_partial_group_by,
                             aggr_expr.clone(),
+                            filter_expr.clone(),
                             input.clone(),
                             input_schema.clone(),
                         )?))
@@ -446,6 +451,7 @@ fn reorder_aggregate_keys(
                         AggregateMode::FinalPartitioned,
                         new_group_by,
                         aggr_expr.to_vec(),
+                        filter_expr.to_vec(),
                         partial_agg,
                         input_schema.clone(),
                     )?);
@@ -1067,11 +1073,13 @@ mod tests {
                 AggregateMode::FinalPartitioned,
                 final_grouping,
                 vec![],
+                vec![],
                 Arc::new(
                     AggregateExec::try_new(
                         AggregateMode::Partial,
                         group_by,
                         vec![],
+                        vec![],
                         input,
                         schema.clone(),
                     )
diff --git a/datafusion/core/src/physical_optimizer/repartition.rs b/datafusion/core/src/physical_optimizer/repartition.rs
index 3bb21b12be..1db61e379e 100644
--- a/datafusion/core/src/physical_optimizer/repartition.rs
+++ b/datafusion/core/src/physical_optimizer/repartition.rs
@@ -477,11 +477,13 @@ mod tests {
                 AggregateMode::Final,
                 PhysicalGroupBy::default(),
                 vec![],
+                vec![],
                 Arc::new(
                     AggregateExec::try_new(
                         AggregateMode::Partial,
                         PhysicalGroupBy::default(),
                         vec![],
+                        vec![],
                         input,
                         schema.clone(),
                     )
diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs
index b1a4da65e0..bada74193b 100644
--- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs
@@ -2469,6 +2469,7 @@ mod tests {
                 AggregateMode::Final,
                 PhysicalGroupBy::default(),
                 vec![],
+                vec![],
                 input,
                 schema,
             )
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index ade0fa0066..3cc8fd5d7d 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -171,6 +171,8 @@ pub struct AggregateExec {
     pub(crate) group_by: PhysicalGroupBy,
     /// Aggregate expressions
     pub(crate) aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+    /// FILTER (WHERE clause) expression for each aggregate expression
+    pub(crate) filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
     /// Input plan, could be a partial aggregate or the input to the aggregate
     pub(crate) input: Arc<dyn ExecutionPlan>,
     /// Schema after the aggregate is applied
@@ -192,6 +194,7 @@ impl AggregateExec {
         mode: AggregateMode,
         group_by: PhysicalGroupBy,
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
     ) -> Result<Self> {
@@ -221,6 +224,7 @@ impl AggregateExec {
             mode,
             group_by,
             aggr_expr,
+            filter_expr,
             input,
             schema,
             input_schema,
@@ -258,6 +262,11 @@ impl AggregateExec {
         &self.aggr_expr
     }
 
+    /// FILTER (WHERE clause) expression for each aggregate expression
+    pub fn filter_expr(&self) -> &[Option<Arc<dyn PhysicalExpr>>] {
+        &self.filter_expr
+    }
+
     /// Input plan
     pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
         &self.input
@@ -281,6 +290,7 @@ impl AggregateExec {
                 self.mode,
                 self.schema.clone(),
                 self.aggr_expr.clone(),
+                self.filter_expr.clone(),
                 input,
                 baseline_metrics,
                 context,
@@ -293,6 +303,7 @@ impl AggregateExec {
                     self.schema.clone(),
                     self.group_by.clone(),
                     self.aggr_expr.clone(),
+                    self.filter_expr.clone(),
                     input,
                     baseline_metrics,
                     batch_size,
@@ -391,6 +402,7 @@ impl ExecutionPlan for AggregateExec {
             self.mode,
             self.group_by.clone(),
             self.aggr_expr.clone(),
+            self.filter_expr.clone(),
             children[0].clone(),
             self.input_schema.clone(),
         )?))
@@ -703,6 +715,20 @@ fn evaluate_many(
         .collect::<Result<Vec<_>>>()
 }
 
+fn evaluate_optional(
+    expr: &[Option<Arc<dyn PhysicalExpr>>],
+    batch: &RecordBatch,
+) -> Result<Vec<Option<ArrayRef>>> {
+    expr.iter()
+        .map(|expr| {
+            expr.as_ref()
+                .map(|expr| expr.evaluate(batch))
+                .transpose()
+                .map(|r| r.map(|v| v.into_array(batch.num_rows())))
+        })
+        .collect::<Result<Vec<_>>>()
+}
+
 fn evaluate_group_by(
     group_by: &PhysicalGroupBy,
     batch: &RecordBatch,
@@ -839,6 +865,7 @@ mod tests {
             AggregateMode::Partial,
             grouping_set.clone(),
             aggregates.clone(),
+            vec![None],
             input,
             input_schema.clone(),
         )?);
@@ -881,6 +908,7 @@ mod tests {
             AggregateMode::Final,
             final_grouping_set,
             aggregates,
+            vec![None],
             merge,
             input_schema,
         )?);
@@ -944,6 +972,7 @@ mod tests {
             AggregateMode::Partial,
             grouping_set.clone(),
             aggregates.clone(),
+            vec![None],
             input,
             input_schema.clone(),
         )?);
@@ -976,6 +1005,7 @@ mod tests {
             AggregateMode::Final,
             final_grouping_set,
             aggregates,
+            vec![None],
             merge,
             input_schema,
         )?);
@@ -1191,6 +1221,7 @@ mod tests {
                 AggregateMode::Partial,
                 groups,
                 aggregates,
+                vec![None; 3],
                 input.clone(),
                 input_schema.clone(),
             )?);
@@ -1246,6 +1277,7 @@ mod tests {
             AggregateMode::Partial,
             groups.clone(),
             aggregates.clone(),
+            vec![None],
             blocking_exec,
             schema,
         )?);
@@ -1284,6 +1316,7 @@ mod tests {
             AggregateMode::Partial,
             groups,
             aggregates.clone(),
+            vec![None],
             blocking_exec,
             schema,
         )?);
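The new `evaluate_optional` helper above leans on the `Option::transpose` idiom: turning each `Option<Result<_>>` into a `Result<Option<_>>` lets a single `collect` short-circuit on the first evaluation error while preserving the `None`s for unfiltered aggregates. A self-contained sketch of the same pattern, with toy types standing in for expressions and arrays:

    fn eval(x: i32) -> Result<i32, String> {
        if x >= 0 { Ok(x * 2) } else { Err("negative input".to_string()) }
    }

    // Option<Result<T, E>> -> Result<Option<T>, E>, then collect() into
    // Result<Vec<Option<T>>, E>.
    let inputs = vec![Some(1), None, Some(3)];
    let outputs: Result<Vec<Option<i32>>, String> =
        inputs.into_iter().map(|v| v.map(eval).transpose()).collect();
    assert_eq!(outputs, Ok(vec![Some(2), None, Some(6)]));
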
diff --git a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
index c13f005b03..efeae8716d 100644
--- a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
+++ b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
@@ -29,10 +29,12 @@ use arrow::record_batch::RecordBatch;
 use datafusion_common::Result;
 use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
 use futures::stream::BoxStream;
+use std::borrow::Cow;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
 use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use crate::physical_plan::filter::batch_filter;
 use futures::stream::{Stream, StreamExt};
 
 /// stream struct for aggregation without grouping columns
@@ -52,23 +54,32 @@ struct AggregateStreamInner {
     input: SendableRecordBatchStream,
     baseline_metrics: BaselineMetrics,
     aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
     accumulators: Vec<AccumulatorItem>,
     reservation: MemoryReservation,
     finished: bool,
 }
 
 impl AggregateStream {
+    #[allow(clippy::too_many_arguments)]
     /// Create a new AggregateStream
     pub fn new(
         mode: AggregateMode,
         schema: SchemaRef,
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
         context: Arc<TaskContext>,
         partition: usize,
     ) -> Result<Self> {
         let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode, 0)?;
+        let filter_expressions = match mode {
+            AggregateMode::Partial => filter_expr,
+            AggregateMode::Final | AggregateMode::FinalPartitioned => {
+                vec![None; aggr_expr.len()]
+            }
+        };
         let accumulators = create_accumulators(&aggr_expr)?;
 
         let reservation = MemoryConsumer::new(format!("AggregateStream[{partition}]"))
@@ -80,6 +91,7 @@ impl AggregateStream {
             input,
             baseline_metrics,
             aggregate_expressions,
+            filter_expressions,
             accumulators,
             reservation,
             finished: false,
@@ -97,9 +109,10 @@ impl AggregateStream {
                         let timer = elapsed_compute.timer();
                         let result = aggregate_batch(
                             &this.mode,
-                            &batch,
+                            batch,
                             &mut this.accumulators,
                             &this.aggregate_expressions,
+                            &this.filter_expressions,
                         );
 
                         timer.done();
@@ -169,29 +182,37 @@ impl RecordBatchStream for AggregateStream {
 /// TODO: Make this a member function
 fn aggregate_batch(
     mode: &AggregateMode,
-    batch: &RecordBatch,
+    batch: RecordBatch,
     accumulators: &mut [AccumulatorItem],
     expressions: &[Vec<Arc<dyn PhysicalExpr>>],
+    filters: &[Option<Arc<dyn PhysicalExpr>>],
 ) -> Result<usize> {
     let mut allocated = 0usize;
 
     // 1.1 iterate accumulators and respective expressions together
-    // 1.2 evaluate expressions
-    // 1.3 update / merge accumulators with the expressions' values
+    // 1.2 filter the batch if necessary
+    // 1.3 evaluate expressions
+    // 1.4 update / merge accumulators with the expressions' values
 
     // 1.1
     accumulators
         .iter_mut()
         .zip(expressions)
-        .try_for_each(|(accum, expr)| {
+        .zip(filters)
+        .try_for_each(|((accum, expr), filter)| {
             // 1.2
+            let batch = match filter {
+                Some(filter) => Cow::Owned(batch_filter(&batch, filter)?),
+                None => Cow::Borrowed(&batch),
+            };
+            // 1.3
             let values = &expr
                 .iter()
-                .map(|e| e.evaluate(batch))
+                .map(|e| e.evaluate(&batch))
                 .map(|r| r.map(|v| v.into_array(batch.num_rows())))
                 .collect::<Result<Vec<_>>>()?;
 
-            // 1.3
+            // 1.4
             let size_pre = accum.size();
             let res = match mode {
                 AggregateMode::Partial => accum.update_batch(values),
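`aggregate_batch` now takes the batch by value and wraps it in a `Cow`, so a filtered copy is materialized only for accumulators that carry a FILTER predicate; accumulators without one keep borrowing the original batch. A minimal standalone sketch of that pattern (the helper name is hypothetical):

    use std::borrow::Cow;

    // Allocate a new collection only when something must be filtered out.
    fn maybe_drop_negatives(values: &[i32]) -> Cow<'_, [i32]> {
        if values.iter().any(|v| *v < 0) {
            Cow::Owned(values.iter().copied().filter(|v| *v >= 0).collect())
        } else {
            Cow::Borrowed(values)
        }
    }

    let batch = vec![1, -2, 3];
    assert_eq!(&maybe_drop_negatives(&batch)[..], &[1, 3]);
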
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index 42ba9f8cb3..3cc2442543 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -32,8 +32,8 @@ use futures::stream::{Stream, StreamExt};
 use crate::execution::context::TaskContext;
 use crate::execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
 use crate::physical_plan::aggregates::{
-    evaluate_group_by, evaluate_many, group_schema, AccumulatorItem, AggregateMode,
-    PhysicalGroupBy, RowAccumulatorItem,
+    evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AccumulatorItem,
+    AggregateMode, PhysicalGroupBy, RowAccumulatorItem,
 };
 use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
 use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr};
@@ -41,9 +41,10 @@ use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
 
 use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use arrow::array::{new_null_array, Array, ArrayRef, PrimitiveArray, UInt32Builder};
-use arrow::compute::cast;
+use arrow::compute::{cast, filter};
 use arrow::datatypes::{DataType, Schema, UInt32Type};
-use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
+use arrow::{compute, datatypes::SchemaRef, record_batch::RecordBatch};
+use datafusion_common::cast::as_boolean_array;
 use datafusion_common::utils::get_arrayref_at_indices;
 use datafusion_common::{Result, ScalarValue};
 use datafusion_expr::Accumulator;
@@ -73,21 +74,26 @@ pub(crate) struct GroupedHashAggregateStream {
     schema: SchemaRef,
     input: SendableRecordBatchStream,
     mode: AggregateMode,
-    exec_state: ExecutionState,
+
     normal_aggr_expr: Vec<Arc<dyn AggregateExpr>>,
-    row_aggr_state: RowAggregationState,
     /// Aggregate expressions not supporting row accumulation
     normal_aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    /// Filter expression for each normal aggregate expression
+    normal_filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
+
     /// Aggregate expressions supporting row accumulation
     row_aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
-
-    group_by: PhysicalGroupBy,
+    /// Filter expression for each row aggregate expression
+    row_filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
     row_accumulators: Vec<RowAccumulatorItem>,
-
     row_converter: RowConverter,
     row_aggr_schema: SchemaRef,
     row_aggr_layout: Arc<RowLayout>,
 
+    group_by: PhysicalGroupBy,
+
+    aggr_state: AggregationState,
+    exec_state: ExecutionState,
     baseline_metrics: BaselineMetrics,
     random_state: RandomState,
     /// size to be used for resulting RecordBatches
@@ -125,6 +131,7 @@ impl GroupedHashAggregateStream {
         schema: SchemaRef,
         group_by: PhysicalGroupBy,
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
         batch_size: usize,
@@ -137,15 +144,26 @@ impl GroupedHashAggregateStream {
         let mut row_aggr_expr = vec![];
         let mut row_agg_indices = vec![];
         let mut row_aggregate_expressions = vec![];
+        let mut row_filter_expressions = vec![];
         let mut normal_aggr_expr = vec![];
         let mut normal_agg_indices = vec![];
         let mut normal_aggregate_expressions = vec![];
+        let mut normal_filter_expressions = vec![];
         // The expressions to evaluate the batch, one vec of expressions per aggregation.
         // Assuming create_schema() always puts group columns in front of aggregation columns, we set
         // col_idx_base to the group expression count.
         let all_aggregate_expressions =
             aggregates::aggregate_expressions(&aggr_expr, &mode, start_idx)?;
-        for (expr, others) in aggr_expr.iter().zip(all_aggregate_expressions.into_iter())
+        let filter_expressions = match mode {
+            AggregateMode::Partial => filter_expr,
+            AggregateMode::Final | AggregateMode::FinalPartitioned => {
+                vec![None; aggr_expr.len()]
+            }
+        };
+        for ((expr, others), filter) in aggr_expr
+            .iter()
+            .zip(all_aggregate_expressions.into_iter())
+            .zip(filter_expressions.into_iter())
         {
             let n_fields = match mode {
                 // In partial aggregation, we keep additional fields in order to successfully
@@ -160,10 +178,12 @@ impl GroupedHashAggregateStream {
             };
             if expr.row_accumulator_supported() {
                 row_aggregate_expressions.push(others);
+                row_filter_expressions.push(filter.clone());
                 row_agg_indices.push(aggr_range);
                 row_aggr_expr.push(expr.clone());
             } else {
                 normal_aggregate_expressions.push(others);
+                normal_filter_expressions.push(filter.clone());
                 normal_agg_indices.push(aggr_range);
                 normal_aggr_expr.push(expr.clone());
             }
@@ -187,7 +207,7 @@ impl GroupedHashAggregateStream {
             Arc::new(RowLayout::new(&row_aggr_schema, RowType::WordAligned));
 
         let name = format!("GroupedHashAggregateStream[{partition}]");
-        let row_aggr_state = RowAggregationState {
+        let aggr_state = AggregationState {
             reservation: MemoryConsumer::new(name).register(context.memory_pool()),
             map: RawTable::with_capacity(0),
             group_states: Vec::with_capacity(0),
@@ -199,19 +219,21 @@ impl GroupedHashAggregateStream {
 
         Ok(GroupedHashAggregateStream {
             schema: Arc::clone(&schema),
-            mode,
-            exec_state,
             input,
-            group_by,
+            mode,
             normal_aggr_expr,
+            normal_aggregate_expressions,
+            normal_filter_expressions,
+            row_aggregate_expressions,
+            row_filter_expressions,
             row_accumulators,
             row_converter,
             row_aggr_schema,
             row_aggr_layout,
+            group_by,
+            aggr_state,
+            exec_state,
             baseline_metrics,
-            normal_aggregate_expressions,
-            row_aggregate_expressions,
-            row_aggr_state,
             random_state: Default::default(),
             batch_size,
             row_group_skip_position: 0,
@@ -243,7 +265,7 @@ impl Stream for GroupedHashAggregateStream {
                             // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
                             // overshooting a bit. Also this means we either store the whole record batch or not.
                             let result = result.and_then(|allocated| {
-                                self.row_aggr_state.reservation.try_grow(allocated)
+                                self.aggr_state.reservation.try_grow(allocated)
                             });
 
                             if let Err(e) = result {
@@ -312,25 +334,23 @@ impl GroupedHashAggregateStream {
         let mut batch_hashes = vec![0; n_rows];
         create_hashes(group_values, &self.random_state, &mut batch_hashes)?;
 
-        let RowAggregationState {
-            map: row_map,
-            group_states: row_group_states,
-            ..
-        } = &mut self.row_aggr_state;
+        let AggregationState {
+            map, group_states, ..
+        } = &mut self.aggr_state;
 
         for (row, hash) in batch_hashes.into_iter().enumerate() {
-            let entry = row_map.get_mut(hash, |(_hash, group_idx)| {
+            let entry = map.get_mut(hash, |(_hash, group_idx)| {
                 // verify that a group that we are inserting with hash is
                 // actually the same key value as the group in
                 // existing_idx  (aka group_values @ row)
-                let group_state = &row_group_states[*group_idx];
+                let group_state = &group_states[*group_idx];
                 group_rows.row(row) == group_state.group_by_values.row()
             });
 
             match entry {
                 // Existing entry for this group value
                 Some((_hash, group_idx)) => {
-                    let group_state = &mut row_group_states[*group_idx];
+                    let group_state = &mut group_states[*group_idx];
 
                     // 1.3
                     if group_state.indices.is_empty() {
@@ -344,7 +364,7 @@ impl GroupedHashAggregateStream {
                     let accumulator_set =
                         aggregates::create_accumulators(&self.normal_aggr_expr)?;
                     // Add new entry to group_states and save newly created index
-                    let group_state = RowGroupState {
+                    let group_state = GroupState {
                         group_by_values: group_rows.row(row).owned(),
                         aggregation_buffer: vec![
                             0;
@@ -353,9 +373,9 @@ impl GroupedHashAggregateStream {
                         accumulator_set,
                         indices: vec![row as u32], // 1.3
                     };
-                    let group_idx = row_group_states.len();
+                    let group_idx = group_states.len();
 
-                    // NOTE: do NOT include the `RowGroupState` struct size in here because this is captured by
+                    // NOTE: do NOT include the `GroupState` struct size in here because this is captured by
                     // `group_states` (see allocation down below)
                     *allocated += (std::mem::size_of::<u8>()
                         * group_state.group_by_values.as_ref().len())
@@ -373,13 +393,13 @@ impl GroupedHashAggregateStream {
                             .sum::<usize>();
 
                     // for hasher function, use precomputed hash value
-                    row_map.insert_accounted(
+                    map.insert_accounted(
                         (hash, group_idx),
                         |(hash, _group_index)| *hash,
                         allocated,
                     );
 
-                    row_group_states.push_accounted(group_state, allocated);
+                    group_states.push_accounted(group_state, allocated);
 
                     groups_with_rows.push(group_idx);
                 }
@@ -389,12 +409,15 @@ impl GroupedHashAggregateStream {
     }
 
     // Update the accumulator results, according to row_aggr_state.
+    #[allow(clippy::too_many_arguments)]
     fn update_accumulators(
         &mut self,
         groups_with_rows: &[usize],
         offsets: &[usize],
         row_values: &[Vec<ArrayRef>],
         normal_values: &[Vec<ArrayRef>],
+        row_filter_values: &[Option<ArrayRef>],
+        normal_filter_values: &[Option<ArrayRef>],
         allocated: &mut usize,
     ) -> Result<()> {
         // 2.1 for each key in this batch
@@ -406,24 +429,19 @@ impl GroupedHashAggregateStream {
             .iter()
             .zip(offsets.windows(2))
             .try_for_each(|(group_idx, offsets)| {
-                let group_state = &mut self.row_aggr_state.group_states[*group_idx];
+                let group_state = &mut self.aggr_state.group_states[*group_idx];
                 // 2.2
+                // Process row accumulators
                 self.row_accumulators
                     .iter_mut()
                     .zip(row_values.iter())
-                    .map(|(accumulator, aggr_array)| {
-                        (
-                            accumulator,
-                            aggr_array
-                                .iter()
-                                .map(|array| {
-                                    // 2.3
-                                    array.slice(offsets[0], offsets[1] - offsets[0])
-                                })
-                                .collect::<Vec<ArrayRef>>(),
-                        )
-                    })
-                    .try_for_each(|(accumulator, values)| {
+                    .zip(row_filter_values.iter())
+                    .try_for_each(|((accumulator, aggr_array), filter_opt)| {
+                        let values = slice_and_maybe_filter(
+                            aggr_array,
+                            filter_opt.as_ref(),
+                            offsets,
+                        )?;
                         let mut state_accessor =
                             RowAccessor::new_from_layout(self.row_aggr_layout.clone());
                         state_accessor
@@ -437,27 +455,19 @@ impl GroupedHashAggregateStream {
                                 accumulator.merge_batch(&values, &mut state_accessor)
                             }
                         }
-                    })
-                    // 2.5
-                    .and(Ok(()))?;
+                    })?;
                 // normal accumulators
                 group_state
                     .accumulator_set
                     .iter_mut()
                     .zip(normal_values.iter())
-                    .map(|(accumulator, aggr_array)| {
-                        (
-                            accumulator,
-                            aggr_array
-                                .iter()
-                                .map(|array| {
-                                    // 2.3
-                                    array.slice(offsets[0], offsets[1] - offsets[0])
-                                })
-                                .collect::<Vec<ArrayRef>>(),
-                        )
-                    })
-                    .try_for_each(|(accumulator, values)| {
+                    .zip(normal_filter_values.iter())
+                    .try_for_each(|((accumulator, aggr_array), filter_opt)| {
+                        let values = slice_and_maybe_filter(
+                            aggr_array,
+                            filter_opt.as_ref(),
+                            offsets,
+                        )?;
                         let size_pre = accumulator.size();
                         let res = match self.mode {
                             AggregateMode::Partial => accumulator.update_batch(&values),
@@ -496,6 +506,9 @@ impl GroupedHashAggregateStream {
             evaluate_many(&self.row_aggregate_expressions, &batch)?;
         let normal_aggr_input_values =
             evaluate_many(&self.normal_aggregate_expressions, &batch)?;
+        let row_filter_values = evaluate_optional(&self.row_filter_expressions, &batch)?;
+        let normal_filter_values =
+            evaluate_optional(&self.normal_filter_expressions, &batch)?;
 
         let row_converter_size_pre = self.row_converter.size();
         for group_values in &group_by_values {
@@ -507,7 +520,7 @@ impl GroupedHashAggregateStream {
             let mut offsets = vec![0];
             let mut offset_so_far = 0;
             for &group_idx in groups_with_rows.iter() {
-                let indices = &self.row_aggr_state.group_states[group_idx].indices;
+                let indices = &self.aggr_state.group_states[group_idx].indices;
                 batch_indices.append_slice(indices);
                 offset_so_far += indices.len();
                 offsets.push(offset_so_far);
@@ -517,11 +530,17 @@ impl GroupedHashAggregateStream {
             let row_values = get_at_indices(&row_aggr_input_values, &batch_indices)?;
             let normal_values =
                 get_at_indices(&normal_aggr_input_values, &batch_indices)?;
+            let row_filter_values =
+                get_optional_filters(&row_filter_values, &batch_indices);
+            let normal_filter_values =
+                get_optional_filters(&normal_filter_values, &batch_indices);
             self.update_accumulators(
                 &groups_with_rows,
                 &offsets,
                 &row_values,
                 &normal_values,
+                &row_filter_values,
+                &normal_filter_values,
                 &mut allocated,
             )?;
         }
@@ -535,7 +554,7 @@ impl GroupedHashAggregateStream {
 
 /// The state that is built for each output group.
 #[derive(Debug)]
-pub struct RowGroupState {
+pub struct GroupState {
     /// The actual group by values, stored sequentially
     group_by_values: OwnedRow,
 
@@ -551,7 +570,7 @@ pub struct RowGroupState {
 }
 
 /// The state of all the groups
-pub struct RowAggregationState {
+pub struct AggregationState {
     pub reservation: MemoryReservation,
 
     /// Logically maps group values to an index in `group_states`
@@ -564,10 +583,10 @@ pub struct RowAggregationState {
     pub map: RawTable<(u64, usize)>,
 
     /// State for each group
-    pub group_states: Vec<RowGroupState>,
+    pub group_states: Vec<GroupState>,
 }
 
-impl std::fmt::Debug for RowAggregationState {
+impl std::fmt::Debug for AggregationState {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // hashes are not stored inline, so we can only get values
         let map_string = "RawTable";
@@ -582,19 +601,19 @@ impl GroupedHashAggregateStream {
     /// Create a RecordBatch with all group keys and accumulators' states or values.
     fn create_batch_from_map(&mut self) -> Result<Option<RecordBatch>> {
         let skip_items = self.row_group_skip_position;
-        if skip_items > self.row_aggr_state.group_states.len() {
+        if skip_items > self.aggr_state.group_states.len() {
             return Ok(None);
         }
-        if self.row_aggr_state.group_states.is_empty() {
+        if self.aggr_state.group_states.is_empty() {
             let schema = self.schema.clone();
             return Ok(Some(RecordBatch::new_empty(schema)));
         }
 
         let end_idx = min(
             skip_items + self.batch_size,
-            self.row_aggr_state.group_states.len(),
+            self.aggr_state.group_states.len(),
         );
-        let group_state_chunk = &self.row_aggr_state.group_states[skip_items..end_idx];
+        let group_state_chunk = &self.aggr_state.group_states[skip_items..end_idx];
 
         if group_state_chunk.is_empty() {
             let schema = self.schema.clone();
@@ -648,8 +667,8 @@ impl GroupedHashAggregateStream {
             for (field_idx, field) in output_fields[start..end].iter().enumerate() {
                 let current = match self.mode {
                     AggregateMode::Partial => ScalarValue::iter_to_array(
-                        group_state_chunk.iter().map(|row_group_state| {
-                            row_group_state.accumulator_set[idx]
+                        group_state_chunk.iter().map(|group_state| {
+                            group_state.accumulator_set[idx]
                                 .state()
                                 .map(|v| v[field_idx].clone())
                                 .expect("Unexpected accumulator state in hash 
aggregate")
@@ -657,8 +676,8 @@ impl GroupedHashAggregateStream {
                     ),
                     AggregateMode::Final | AggregateMode::FinalPartitioned => {
                         ScalarValue::iter_to_array(group_state_chunk.iter().map(
-                            |row_group_state| {
-                                row_group_state.accumulator_set[idx].evaluate().expect(
+                            |group_state| {
+                                group_state.accumulator_set[idx].evaluate().expect(
                                     "Unexpected accumulator state in hash aggregate",
                                 )
                             },
@@ -726,3 +745,47 @@ fn get_at_indices(
         .map(|array| get_arrayref_at_indices(array, batch_indices))
         .collect()
 }
+
+fn get_optional_filters(
+    original_values: &[Option<Arc<dyn Array>>],
+    batch_indices: &PrimitiveArray<UInt32Type>,
+) -> Vec<Option<Arc<dyn Array>>> {
+    original_values
+        .iter()
+        .map(|array| {
+            array.as_ref().map(|array| {
+                compute::take(
+                    array.as_ref(),
+                    batch_indices,
+                    None, // None: no index check
+                )
+                .unwrap()
+            })
+        })
+        .collect()
+}
+
+fn slice_and_maybe_filter(
+    aggr_array: &[ArrayRef],
+    filter_opt: Option<&Arc<dyn Array>>,
+    offsets: &[usize],
+) -> Result<Vec<ArrayRef>> {
+    let sliced_arrays: Vec<ArrayRef> = aggr_array
+        .iter()
+        .map(|array| array.slice(offsets[0], offsets[1] - offsets[0]))
+        .collect();
+
+    let filtered_arrays = match filter_opt.as_ref() {
+        Some(f) => {
+            let sliced = f.slice(offsets[0], offsets[1] - offsets[0]);
+            let filter_array = as_boolean_array(&sliced)?;
+
+            sliced_arrays
+                .iter()
+                .map(|array| filter(array, filter_array).unwrap())
+                .collect::<Vec<ArrayRef>>()
+        }
+        None => sliced_arrays,
+    };
+    Ok(filtered_arrays)
+}
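`slice_and_maybe_filter` composes two arrow kernels per group: slice out the group's rows, then apply the boolean FILTER mask with `arrow::compute::filter`. A small sketch of those two steps in isolation, assuming the arrow-rs API already used in this file:

    use std::sync::Arc;
    use arrow::array::{ArrayRef, BooleanArray, Int32Array};
    use arrow::compute::filter;

    let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30, 40]));
    // Slice the rows belonging to one group (rows 1..3) ...
    let sliced = values.slice(1, 2);
    // ... then keep only the rows whose FILTER predicate was true.
    let predicate = BooleanArray::from(vec![true, false]);
    let kept = filter(sliced.as_ref(), &predicate).unwrap();
    assert_eq!(kept.len(), 1);
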
diff --git a/datafusion/core/src/physical_plan/filter.rs b/datafusion/core/src/physical_plan/filter.rs
index a72aa69d07..494d3fc869 100644
--- a/datafusion/core/src/physical_plan/filter.rs
+++ b/datafusion/core/src/physical_plan/filter.rs
@@ -235,7 +235,7 @@ struct FilterExecStream {
     baseline_metrics: BaselineMetrics,
 }
 
-fn batch_filter(
+pub(crate) fn batch_filter(
     batch: &RecordBatch,
     predicate: &Arc<dyn PhysicalExpr>,
 ) -> Result<RecordBatch> {
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 2064357f7d..8ee32b8d04 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -627,10 +627,10 @@ impl DefaultPhysicalPlanner {
                         &physical_input_schema,
                         session_state)?;
 
-                    let aggregates = aggr_expr
+                    let agg_filter = aggr_expr
                         .iter()
                         .map(|e| {
-                            create_aggregate_expr(
+                            create_aggregate_expr_and_maybe_filter(
                                 e,
                                 logical_input_schema,
                                 &physical_input_schema,
@@ -638,11 +638,13 @@ impl DefaultPhysicalPlanner {
                             )
                         })
                         .collect::<Result<Vec<_>>>()?;
+                    let (aggregates, filters): (Vec<_>, Vec<_>) = agg_filter.into_iter().unzip();
 
                     let initial_aggr = Arc::new(AggregateExec::try_new(
                         AggregateMode::Partial,
                         groups.clone(),
                         aggregates.clone(),
+                        filters.clone(),
                         input_exec,
                         physical_input_schema.clone(),
                     )?);
@@ -678,6 +680,7 @@ impl DefaultPhysicalPlanner {
                         next_partition_mode,
                         final_grouping_set,
                         aggregates,
+                        filters,
                         initial_aggr,
                         physical_input_schema.clone(),
                     )?))
@@ -1609,20 +1612,23 @@ pub fn create_window_expr(
     )
 }
 
+type AggregateExprWithOptionalFilter =
+    (Arc<dyn AggregateExpr>, Option<Arc<dyn PhysicalExpr>>);
+
 /// Create an aggregate expression with a name from a logical expression
-pub fn create_aggregate_expr_with_name(
+pub fn create_aggregate_expr_with_name_and_maybe_filter(
     e: &Expr,
     name: impl Into<String>,
     logical_input_schema: &DFSchema,
     physical_input_schema: &Schema,
     execution_props: &ExecutionProps,
-) -> Result<Arc<dyn AggregateExpr>> {
+) -> Result<AggregateExprWithOptionalFilter> {
     match e {
         Expr::AggregateFunction(AggregateFunction {
             fun,
             distinct,
             args,
-            ..
+            filter,
         }) => {
             let args = args
                 .iter()
@@ -1635,15 +1641,25 @@ pub fn create_aggregate_expr_with_name(
                     )
                 })
                 .collect::<Result<Vec<_>>>()?;
-            aggregates::create_aggregate_expr(
+            let filter = match filter {
+                Some(e) => Some(create_physical_expr(
+                    e,
+                    logical_input_schema,
+                    physical_input_schema,
+                    execution_props,
+                )?),
+                None => None,
+            };
+            let agg_expr = aggregates::create_aggregate_expr(
                 fun,
                 *distinct,
                 &args,
                 physical_input_schema,
                 name,
-            )
+            );
+            Ok((agg_expr?, filter))
         }
-        Expr::AggregateUDF { fun, args, .. } => {
+        Expr::AggregateUDF { fun, args, filter } => {
             let args = args
                 .iter()
                 .map(|e| {
@@ -1656,7 +1672,19 @@ pub fn create_aggregate_expr_with_name(
                 })
                 .collect::<Result<Vec<_>>>()?;
 
-            udaf::create_aggregate_expr(fun, &args, physical_input_schema, name)
+            let filter = match filter {
+                Some(e) => Some(create_physical_expr(
+                    e,
+                    logical_input_schema,
+                    physical_input_schema,
+                    execution_props,
+                )?),
+                None => None,
+            };
+
+            let agg_expr =
+                udaf::create_aggregate_expr(fun, &args, physical_input_schema, name);
+            Ok((agg_expr?, filter))
         }
         other => Err(DataFusionError::Internal(format!(
             "Invalid aggregate expression '{other:?}'"
@@ -1665,19 +1693,19 @@ pub fn create_aggregate_expr_with_name(
 }
 
 /// Create an aggregate expression from a logical expression or an alias
-pub fn create_aggregate_expr(
+pub fn create_aggregate_expr_and_maybe_filter(
     e: &Expr,
     logical_input_schema: &DFSchema,
     physical_input_schema: &Schema,
     execution_props: &ExecutionProps,
-) -> Result<Arc<dyn AggregateExpr>> {
+) -> Result<AggregateExprWithOptionalFilter> {
     // unpack (nested) aliased logical expressions, e.g. "sum(col) as total"
     let (name, e) = match e {
         Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()),
         _ => (physical_name(e)?, e),
     };
 
-    create_aggregate_expr_with_name(
+    create_aggregate_expr_with_name_and_maybe_filter(
         e,
         name,
         logical_input_schema,
@@ -1788,7 +1816,10 @@ impl DefaultPhysicalPlanner {
             "Input physical plan:\n{}\n",
             displayable(plan.as_ref()).indent()
         );
-        trace!("Detailed input physical plan:\n{:?}", plan);
+        trace!(
+            "Detailed input physical plan:\n{}",
+            displayable(plan.as_ref()).indent()
+        );
 
         let mut new_plan = plan;
         for optimizer in optimizers {
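The planner now builds one `(Arc<dyn AggregateExpr>, Option<Arc<dyn PhysicalExpr>>)` pair per aggregate and splits the pairs with `unzip`, so `AggregateExec::try_new` receives two parallel vectors. A trivial sketch of that collect-then-unzip step, with strings standing in for the expression types:

    // Illustrative stand-ins for the (aggregate, optional filter) pairs.
    let agg_filter = vec![("SUM(c2)", Some("c2 >= 20")), ("COUNT(c1)", None)];
    let (aggregates, filters): (Vec<_>, Vec<_>) = agg_filter.into_iter().unzip();
    assert_eq!(aggregates, vec!["SUM(c2)", "COUNT(c1)"]);
    assert_eq!(filters, vec![Some("c2 >= 20"), None]);
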
diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
index b049e4b16c..10368341d8 100644
--- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt
@@ -1538,3 +1538,83 @@ query RT
 select avg(c1), arrow_typeof(avg(c1)) from d_table
 ----
 5 Decimal128(14, 7)
+
+# Use the PostgreSQL dialect
+statement ok
+set datafusion.sql_parser.dialect = 'Postgres';
+
+# Creating the table
+statement ok
+CREATE TABLE test_table (c1 INT, c2 INT, c3 INT)
+
+# Inserting data
+statement ok
+INSERT INTO test_table VALUES (1, 10, 50), (1, 20, 60), (2, 10, 70), (2, 20, 80), (3, 10, NULL)
+
+# query_group_by_with_filter
+query II rowsort
+SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) as result FROM test_table GROUP BY c1
+----
+1 20
+2 20
+3 NULL
+
+# query_group_by_avg_with_filter
+query IR rowsort
+SELECT c1, AVG(c2) FILTER (WHERE c2 >= 20) AS avg_c2 FROM test_table GROUP BY c1
+----
+1 20
+2 20
+3 NULL
+
+# query_group_by_with_multiple_filters
+query IIR rowsort
+SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2, AVG(c3) FILTER (WHERE c3 <= 70) AS avg_c3 FROM test_table GROUP BY c1
+----
+1 20 55
+2 20 70
+3 NULL NULL
+
+# query_group_by_distinct_with_filter
+query II rowsort
+SELECT c1, COUNT(DISTINCT c2) FILTER (WHERE c2 >= 20) AS distinct_c2_count FROM test_table GROUP BY c1
+----
+1 1
+2 1
+3 0
+
+# query_without_group_by_with_filter
+query I rowsort
+SELECT SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2 FROM test_table
+----
+40
+
+# count_without_group_by_with_filter
+query I rowsort
+SELECT COUNT(c2) FILTER (WHERE c2 >= 20) AS count_c2 FROM test_table
+----
+2
+
+# query_with_and_without_filter
+query III rowsort
+SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) as result, SUM(c2) as result_no_filter FROM test_table GROUP BY c1;
+----
+1 20 30
+2 20 30
+3 NULL 10
+
+# query_filter_on_different_column_than_aggregate
+query I rowsort
+select sum(c1) FILTER (WHERE c2 < 30) from test_table;
+----
+9
+
+# query_test_empty_filter
+query I rowsort
+SELECT SUM(c2) FILTER (WHERE c2 >= 20000000) AS sum_c2 FROM test_table;
+----
+NULL
+
+# Restore the default dialect
+statement ok
+set datafusion.sql_parser.dialect = 'Generic';
diff --git a/datafusion/core/tests/sqllogictests/test_files/information_schema.slt b/datafusion/core/tests/sqllogictests/test_files/information_schema.slt
index 3adf5585d7..80187564f9 100644
--- a/datafusion/core/tests/sqllogictests/test_files/information_schema.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/information_schema.slt
@@ -163,6 +163,7 @@ datafusion.optimizer.repartition_sorts true
 datafusion.optimizer.repartition_windows true
 datafusion.optimizer.skip_failed_rules true
 datafusion.optimizer.top_down_join_key_reordering true
+datafusion.sql_parser.dialect generic
 datafusion.sql_parser.enable_ident_normalization true
 datafusion.sql_parser.parse_float_as_decimal false
 
diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs
index 61a5c91fec..b0a5e31da0 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -297,7 +297,7 @@ impl TreeNode for Expr {
                 fun,
                 transform_vec(args, &mut transform)?,
                 distinct,
-                filter,
+                transform_option_box(filter, &mut transform)?,
             )),
             Expr::GroupingSet(grouping_set) => match grouping_set {
                 GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup(
@@ -318,7 +318,7 @@ impl TreeNode for Expr {
             Expr::AggregateUDF { args, fun, filter } => Expr::AggregateUDF {
                 args: transform_vec(args, &mut transform)?,
                 fun,
-                filter,
+                filter: transform_option_box(filter, &mut transform)?,
             },
             Expr::InList {
                 expr,
diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs
index fd8f4c011a..97ba5a92d7 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -1030,7 +1030,7 @@ mod tests {
             )?
             .build()?;
 
-        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b), 
COUNT(test.b) FILTER (WHERE c > Int32(42)) AS count2]]\
+        let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b), 
COUNT(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\
         \n  TableScan: test projection=[a, b, c]";
 
         assert_optimized_plan_eq(&plan, expected)
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 875a6a15cf..c4b3ac2114 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1258,6 +1258,10 @@ message WindowAggExecNode {
   Schema input_schema = 4;
 }
 
+message MaybeFilter {
+  PhysicalExprNode expr = 1;
+}
+
 message AggregateExecNode {
   repeated PhysicalExprNode group_expr = 1;
   repeated PhysicalExprNode aggr_expr = 2;
@@ -1269,6 +1273,7 @@ message AggregateExecNode {
   Schema input_schema = 7;
   repeated PhysicalExprNode null_expr = 8;
   repeated bool groups = 9;
+  repeated MaybeFilter filter_expr = 10;
 }
 
 message GlobalLimitExecNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 63a8a2ed00..105591a000 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -33,6 +33,9 @@ impl serde::Serialize for AggregateExecNode {
         if !self.groups.is_empty() {
             len += 1;
         }
+        if !self.filter_expr.is_empty() {
+            len += 1;
+        }
         let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExecNode", len)?;
         if !self.group_expr.is_empty() {
             struct_ser.serialize_field("groupExpr", &self.group_expr)?;
@@ -63,6 +66,9 @@ impl serde::Serialize for AggregateExecNode {
         if !self.groups.is_empty() {
             struct_ser.serialize_field("groups", &self.groups)?;
         }
+        if !self.filter_expr.is_empty() {
+            struct_ser.serialize_field("filterExpr", &self.filter_expr)?;
+        }
         struct_ser.end()
     }
 }
@@ -88,6 +94,8 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
             "null_expr",
             "nullExpr",
             "groups",
+            "filter_expr",
+            "filterExpr",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -101,6 +109,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
             InputSchema,
             NullExpr,
             Groups,
+            FilterExpr,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -131,6 +140,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
                             "inputSchema" | "input_schema" => 
Ok(GeneratedField::InputSchema),
                             "nullExpr" | "null_expr" => 
Ok(GeneratedField::NullExpr),
                             "groups" => Ok(GeneratedField::Groups),
+                            "filterExpr" | "filter_expr" => Ok(GeneratedField::FilterExpr),
                             _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                         }
                     }
@@ -159,6 +169,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
                 let mut input_schema__ = None;
                 let mut null_expr__ = None;
                 let mut groups__ = None;
+                let mut filter_expr__ = None;
                 while let Some(k) = map.next_key()? {
                     match k {
                         GeneratedField::GroupExpr => {
@@ -215,6 +226,12 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
                             }
                             groups__ = Some(map.next_value()?);
                         }
+                        GeneratedField::FilterExpr => {
+                            if filter_expr__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("filterExpr"));
+                            }
+                            filter_expr__ = Some(map.next_value()?);
+                        }
                     }
                 }
                 Ok(AggregateExecNode {
@@ -227,6 +244,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode {
                     input_schema: input_schema__,
                     null_expr: null_expr__.unwrap_or_default(),
                     groups: groups__.unwrap_or_default(),
+                    filter_expr: filter_expr__.unwrap_or_default(),
                 })
             }
         }
@@ -11280,6 +11298,97 @@ impl<'de> serde::Deserialize<'de> for Map {
         deserializer.deserialize_struct("datafusion.Map", FIELDS, GeneratedVisitor)
     }
 }
+impl serde::Serialize for MaybeFilter {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if self.expr.is_some() {
+            len += 1;
+        }
+        let mut struct_ser = serializer.serialize_struct("datafusion.MaybeFilter", len)?;
+        if let Some(v) = self.expr.as_ref() {
+            struct_ser.serialize_field("expr", v)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for MaybeFilter {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "expr",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Expr,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "expr" => Ok(GeneratedField::Expr),
+                            _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = MaybeFilter;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                formatter.write_str("struct datafusion.MaybeFilter")
+            }
+
+            fn visit_map<V>(self, mut map: V) -> std::result::Result<MaybeFilter, V::Error>
+                where
+                    V: serde::de::MapAccess<'de>,
+            {
+                let mut expr__ = None;
+                while let Some(k) = map.next_key()? {
+                    match k {
+                        GeneratedField::Expr => {
+                            if expr__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("expr"));
+                            }
+                            expr__ = map.next_value()?;
+                        }
+                    }
+                }
+                Ok(MaybeFilter {
+                    expr: expr__,
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion.MaybeFilter", FIELDS, GeneratedVisitor)
+    }
+}
 impl serde::Serialize for NegativeNode {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
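
As the visitor above shows, the generated serde code accepts both the proto field name and its JSON camelCase alias (`filter_expr` / `filterExpr`) on input, and an object with no `expr` key decodes to an unset filter. A hedged round-trip sketch, assuming the crate's serde support is enabled and the generated types are in scope:

    // "{}" carries no "expr" key, so the visitor yields expr: None
    let mf: MaybeFilter = serde_json::from_str("{}")?;
    assert!(mf.expr.is_none());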
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 7764fe7848..ebdd14b2f3 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1791,6 +1791,12 @@ pub struct WindowAggExecNode {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct MaybeFilter {
+    #[prost(message, optional, tag = "1")]
+    pub expr: ::core::option::Option<PhysicalExprNode>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct AggregateExecNode {
     #[prost(message, repeated, tag = "1")]
     pub group_expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
@@ -1811,6 +1817,8 @@ pub struct AggregateExecNode {
     pub null_expr: ::prost::alloc::vec::Vec<PhysicalExprNode>,
     #[prost(bool, repeated, tag = "9")]
     pub groups: ::prost::alloc::vec::Vec<bool>,
+    #[prost(message, repeated, tag = "10")]
+    pub filter_expr: ::prost::alloc::vec::Vec<MaybeFilter>,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
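
Appending `filter_expr` under a fresh tag (10) keeps the wire format backward compatible: payloads encoded before this change still decode, just with an empty vector. A minimal sketch:

    // prost derives Default; an old payload with no tag-10 entries decodes equivalently
    let node = AggregateExecNode::default();
    assert!(node.filter_expr.is_empty());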
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 8fd57f002b..ff13bbfb8f 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -403,6 +403,18 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 let physical_schema: SchemaRef =
                     SchemaRef::new((&input_schema).try_into()?);
 
+                let physical_filter_expr = hash_agg
+                    .filter_expr
+                    .iter()
+                    .map(|expr| {
+                        let x = expr
+                            .expr
+                            .as_ref()
+                            .map(|e| parse_physical_expr(e, registry, &physical_schema));
+                        x.transpose()
+                    })
+                    .collect::<Result<Vec<_>, _>>()?;
+
                 let physical_aggr_expr: Vec<Arc<dyn AggregateExpr>> = hash_agg
                     .aggr_expr
                     .iter()
@@ -450,6 +462,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     agg_mode,
                     PhysicalGroupBy::new(group_expr, null_expr, groups),
                     physical_aggr_expr,
+                    physical_filter_expr,
                     input,
                     Arc::new((&input_schema).try_into()?),
                 )?))
@@ -864,6 +877,12 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 .map(|expr| expr.1.to_owned())
                 .collect();
 
+            let filter = exec
+                .filter_expr()
+                .iter()
+                .map(|expr| expr.to_owned().try_into())
+                .collect::<Result<Vec<_>>>()?;
+
             let agg = exec
                 .aggr_expr()
                 .iter()
@@ -911,6 +930,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
                         group_expr,
                         group_expr_name: group_names,
                         aggr_expr: agg,
+                        filter_expr: filter,
                         aggr_expr_name: agg_names,
                         mode: agg_mode as i32,
                         input: Some(Box::new(input)),
@@ -1391,6 +1411,7 @@ mod roundtrip_tests {
             AggregateMode::Final,
             PhysicalGroupBy::new_single(groups.clone()),
             aggregates.clone(),
+            vec![None],
             Arc::new(EmptyExec::new(false, schema.clone())),
             schema,
         )?))
@@ -1601,6 +1622,7 @@ mod roundtrip_tests {
             AggregateMode::Final,
             PhysicalGroupBy::new_single(groups),
             aggregates.clone(),
+            vec![None],
             Arc::new(EmptyExec::new(false, schema.clone())),
             schema,
         )?))
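
As the roundtrip tests show, `AggregateExec::try_new` now takes one extra vector, index-aligned with the aggregate expressions, where `None` marks an unfiltered aggregate. A sketch of the call shape (variable names as in the tests above):

    let exec = AggregateExec::try_new(
        AggregateMode::Final,
        PhysicalGroupBy::new_single(groups.clone()),
        aggregates.clone(),
        vec![None], // filter_expr: one Option per aggregate expression
        Arc::new(EmptyExec::new(false, schema.clone())),
        schema,
    )?;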
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index 9210a2d7fa..e18932575c 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -498,3 +498,16 @@ impl From<JoinSide> for protobuf::JoinSide {
         }
     }
 }
+
+impl TryFrom<Option<Arc<dyn PhysicalExpr>>> for protobuf::MaybeFilter {
+    type Error = DataFusionError;
+
+    fn try_from(expr: Option<Arc<dyn PhysicalExpr>>) -> Result<Self, Self::Error> {
+        match expr {
+            None => Ok(protobuf::MaybeFilter { expr: None }),
+            Some(expr) => Ok(protobuf::MaybeFilter {
+                expr: Some(expr.try_into()?),
+            }),
+        }
+    }
+}
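
With this `TryFrom` impl in place, each `Option<Arc<dyn PhysicalExpr>>` converts via a plain `try_into()`, and `None` round-trips as a wrapper with `expr` unset, as in this sketch:

    // an unfiltered slot serializes to an empty MaybeFilter
    let empty: protobuf::MaybeFilter = Option::<Arc<dyn PhysicalExpr>>::None.try_into()?;
    assert!(empty.expr.is_none());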
diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs
index 3749e65573..64ca85b72d 100644
--- a/datafusion/sql/tests/integration_test.rs
+++ b/datafusion/sql/tests/integration_test.rs
@@ -3061,8 +3061,8 @@ fn hive_aggregate_with_filter() -> Result<()> {
     let dialect = &HiveDialect {};
     let sql = "SELECT SUM(age) FILTER (WHERE age > 4) FROM person";
     let plan = logical_plan_with_dialect(sql, dialect)?;
-    let expected = "Projection: SUM(person.age) FILTER (WHERE age > Int64(4))\
-        \n  Aggregate: groupBy=[[]], aggr=[[SUM(person.age) FILTER (WHERE age > Int64(4))]]\
+    let expected = "Projection: SUM(person.age) FILTER (WHERE person.age > Int64(4))\
+        \n  Aggregate: groupBy=[[]], aggr=[[SUM(person.age) FILTER (WHERE person.age > Int64(4))]]\
         \n    TableScan: person"
         .to_string();
     assert_eq!(plan.display_indent().to_string(), expected);
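
This test exercises the new dialect plumbing end to end: the query is parsed under `HiveDialect`, and the planner now qualifies `person.age` inside the FILTER clause just as it does for the aggregate arguments. A usage sketch mirroring the test (where `logical_plan_with_dialect` is the test helper called above):

    // plan the same statement under an explicit dialect
    let plan = logical_plan_with_dialect(sql, &HiveDialect {})?;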
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md
index 749a0bcb06..dc21c81942 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -73,3 +73,4 @@ Environment variables are read during `SessionConfig` initialisation so they mus
 | datafusion.explain.physical_plan_only                      | false      | When set to true, the explain statement will only print physical plans [...]
 | datafusion.sql_parser.parse_float_as_decimal               | false      | When set to true, SQL parser will parse float as decimal type [...]
 | datafusion.sql_parser.enable_ident_normalization           | true       | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) [...]
+| datafusion.sql_parser.dialect                              | generic    | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. [...]
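
The new option can be set like any other configuration entry. A hedged usage sketch on the Rust side (setter and constructor names per DataFusion's `SessionConfig` and `SessionContext`; treat the exact API as illustrative):

    // parse subsequent SQL with the Hive dialect instead of the generic default
    let config = SessionConfig::new().set_str("datafusion.sql_parser.dialect", "hive");
    let ctx = SessionContext::with_config(config);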
