alamb commented on code in PR #5868:
URL: https://github.com/apache/arrow-datafusion/pull/5868#discussion_r1160159239


##########
datafusion/common/src/config.rs:
##########
@@ -187,6 +187,10 @@ config_namespace! {
         /// When set to true, SQL parser will normalize ident (convert ident 
to lowercase when not quoted)
         pub enable_ident_normalization: bool, default = true
 
+        /// Configure the SQL dialect used by DataFusion's parser; supported 
values include: Generic,

Review Comment:
   👍 



##########
datafusion/core/src/execution/context.rs:
##########
@@ -1833,6 +1860,29 @@ impl From<&SessionState> for TaskContext {
     }
 }
 
+fn create_dialect_from_str(dialect_name: &str) -> Box<dyn Dialect> {

Review Comment:
   What do you think about putting this in a PR upstream in sqlparser-rs? I can do so if you agree.



##########
datafusion/core/src/execution/context.rs:
##########
@@ -1833,6 +1860,29 @@ impl From<&SessionState> for TaskContext {
     }
 }
 
+fn create_dialect_from_str(dialect_name: &str) -> Box<dyn Dialect> {
+    match dialect_name.to_lowercase().as_str() {
+        "generic" => Box::new(GenericDialect),
+        "mysql" => Box::new(MySqlDialect {}),
+        "postgresql" | "postgres" => Box::new(PostgreSqlDialect {}),
+        "hive" => Box::new(HiveDialect {}),
+        "sqlite" => Box::new(SQLiteDialect {}),
+        "snowflake" => Box::new(SnowflakeDialect),
+        "redshift" => Box::new(RedshiftSqlDialect {}),
+        "mssql" => Box::new(MsSqlDialect {}),
+        "clickhouse" => Box::new(ClickHouseDialect {}),
+        "bigquery" => Box::new(BigQueryDialect),
+        "ansi" => Box::new(AnsiDialect {}),
+        _ => {

Review Comment:
   I think it might be better to return an error here rather than calling `println!`. For one thing, when running in a server or other distributed context, stdout may not be connected to anything.



##########
datafusion/core/src/physical_plan/aggregates/no_grouping.rs:
##########
@@ -172,26 +177,34 @@ fn aggregate_batch(
     batch: &RecordBatch,
     accumulators: &mut [AccumulatorItem],
     expressions: &[Vec<Arc<dyn PhysicalExpr>>],
+    filters: &[Option<Arc<dyn PhysicalExpr>>],
 ) -> Result<usize> {
     let mut allocated = 0usize;
 
     // 1.1 iterate accumulators and respective expressions together
-    // 1.2 evaluate expressions
-    // 1.3 update / merge accumulators with the expressions' values
+    // 1.2 filter the batch if necessary
+    // 1.3 evaluate expressions
+    // 1.4 update / merge accumulators with the expressions' values
 
     // 1.1
     accumulators
         .iter_mut()
         .zip(expressions)
-        .try_for_each(|(accum, expr)| {
+        .zip(filters)
+        .try_for_each(|((accum, expr), filter)| {
             // 1.2
+            let batch = match filter {
+                Some(filter) => batch_filter(batch, filter)?,
+                None => batch.clone(),

Review Comment:
   It would be really nice to figure out how to avoid this clone(). 
   
   Here is one way I found to do so:
   
   ```diff
   diff --git a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs 
b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
   index 8b770f796..88bab512c 100644
   --- a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
   +++ b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
   @@ -29,6 +29,7 @@ use arrow::record_batch::RecordBatch;
    use datafusion_common::Result;
    use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
    use futures::stream::BoxStream;
   +use std::borrow::Cow;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    
   @@ -101,7 +102,7 @@ impl AggregateStream {
                            let timer = elapsed_compute.timer();
                            let result = aggregate_batch(
                                &this.mode,
   -                            &batch,
   +                            batch,
                                &mut this.accumulators,
                                &this.aggregate_expressions,
                                &this.filter_expressions,
   @@ -174,7 +175,7 @@ impl RecordBatchStream for AggregateStream {
    /// TODO: Make this a member function
    fn aggregate_batch(
        mode: &AggregateMode,
   -    batch: &RecordBatch,
   +    batch: RecordBatch,
        accumulators: &mut [AccumulatorItem],
        expressions: &[Vec<Arc<dyn PhysicalExpr>>],
        filters: &[Option<Arc<dyn PhysicalExpr>>],
   @@ -194,8 +195,8 @@ fn aggregate_batch(
            .try_for_each(|((accum, expr), filter)| {
                // 1.2
                let batch = match filter {
   -                Some(filter) => batch_filter(batch, filter)?,
   -                None => batch.clone(),
   +                Some(filter) => Cow::Owned(batch_filter(&batch, filter)?),
   +                None => Cow::Borrowed(&batch),
                };
                // 1.3
                let values = &expr
   ```



##########
datafusion/core/tests/sql/group_by.rs:
##########
@@ -905,3 +905,220 @@ async fn 
csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> {
     assert_batches_sorted_eq!(expected, &actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn query_group_by_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, true),
+        Field::new("c2", DataType::Int32, true),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![
+                Some(1),
+                Some(1),
+                Some(2),
+                Some(2),
+                Some(3),
+            ])),
+            Arc::new(Int32Array::from(vec![
+                Some(10),
+                Some(20),
+                Some(10),
+                Some(20),
+                Some(10),
+            ])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+    let sql =
+        "SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) as result FROM test GROUP 
BY c1";
+
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | result |",
+        "+----+--------+",
+        "| 1  | 20     |",
+        "| 2  | 20     |",
+        "| 3  |        |",
+        "+----+--------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+// test avg since it has two state columns
+#[tokio::test]
+async fn query_group_by_avg_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, false),
+        Field::new("c2", DataType::Int32, false),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2])),
+            Arc::new(Int32Array::from(vec![10, 20, 30, 40])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+
+    let sql =
+        "SELECT c1, AVG(c2) FILTER (WHERE c2 >= 20) AS avg_c2 FROM test GROUP 
BY c1";
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | avg_c2 |",
+        "+----+--------+",
+        "| 1  | 20.0   |",
+        "| 2  | 35.0   |",
+        "+----+--------+",
+    ];
+
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn query_group_by_with_multiple_filters() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, false),
+        Field::new("c2", DataType::Int32, false),
+        Field::new("c3", DataType::Int32, false),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2])),
+            Arc::new(Int32Array::from(vec![10, 20, 30, 40])),
+            Arc::new(Int32Array::from(vec![50, 60, 70, 80])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+
+    let sql = "SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2, AVG(c3) 
FILTER (WHERE c3 <= 70) AS avg_c3 FROM test GROUP BY c1";
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    let expected = vec![
+        "+----+--------+--------+",
+        "| c1 | sum_c2 | avg_c3 |",
+        "+----+--------+--------+",
+        "| 1  | 20     | 55.0   |",
+        "| 2  | 70     | 70.0   |",
+        "+----+--------+--------+",
+    ];
+
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn query_group_by_distinct_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, false),
+        Field::new("c2", DataType::Int32, false),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2, 2])),
+            Arc::new(Int32Array::from(vec![10, 20, 20, 30, 40])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+    let sql = "SELECT c1, COUNT(DISTINCT c2) FILTER (WHERE c2 >= 20) AS 
distinct_c2_count FROM test GROUP BY c1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+-------------------+",
+        "| c1 | distinct_c2_count |",
+        "+----+-------------------+",
+        "| 1  | 1                 |",
+        "| 2  | 3                 |",
+        "+----+-------------------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn query_without_group_by_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, false),
+        Field::new("c2", DataType::Int32, false),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2, 2])),
+            Arc::new(Int32Array::from(vec![10, 20, 20, 30, 40])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+    let sql = "SELECT SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2 FROM test";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+--------+",
+        "| sum_c2 |",
+        "+--------+",
+        "| 110    |",
+        "+--------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+// count is special cased by `aggregate_statistics`

Review Comment:
   Also, it would be good to test the case where the filter removes all rows (i.e. the input to the aggregate is empty):
   
   ```sql
   SELECT SUM(c2) FILTER (WHERE c2 >= 20000000) AS sum_c2 FROM test
   ```



##########
datafusion/physical-expr/src/aggregate/mod.rs:
##########
@@ -77,6 +77,9 @@ pub trait AggregateExpr: Send + Sync + Debug {
     /// Single-column aggregations such as `sum` return a single value, others 
(e.g. `cov`) return many.
     fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
 
+    /// FILTER (WHERE clause) expression for this aggregate

Review Comment:
   I think it makes more sense to track the filters on the `LogicalPlan::GroupBy` and the relevant ExecutionPlan, rather than force every AggregateExpr to carry the filter itself, because:
   
   1. The filtering is the same for all aggregates (it doesn't vary by aggregate type), so having the filter on the aggregate seems to be a mismatch.
   2. It forces user-defined aggregates to all do the same thing (carry along a filter), and if they make a mistake they could get wrong answers.



##########
datafusion/core/src/execution/context.rs:
##########
@@ -1510,6 +1515,27 @@ impl SessionState {
         Ok(statement)
     }
 
+    /// Convert a SQL string into an AST Statement
+    pub fn sql_to_statement_with_dialect(

Review Comment:
   I think this change means that `sql_to_statement` will effectively ignore `ConfigOptions::sql_parser::dialect`, which seems confusing.
   
   Is there a need for a separate `sql_to_statement_with_dialect`, rather than changing `sql_to_statement` to use the configured dialect?



##########
datafusion/core/tests/sql/group_by.rs:
##########
@@ -905,3 +905,220 @@ async fn 
csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> {
     assert_batches_sorted_eq!(expected, &actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn query_group_by_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, true),
+        Field::new("c2", DataType::Int32, true),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![
+                Some(1),
+                Some(1),
+                Some(2),
+                Some(2),
+                Some(3),
+            ])),
+            Arc::new(Int32Array::from(vec![
+                Some(10),
+                Some(20),
+                Some(10),
+                Some(20),
+                Some(10),
+            ])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+    let sql =

Review Comment:
   I think another important case is a query that has aggregates both with and without a filter.
   
   For example:
   
   
   ```sql
   SELECT c1,
     SUM(c2) FILTER (WHERE c2 >= 20) as result,
     SUM(c2) as result_no_filter
   FROM test GROUP BY c1
   ```



##########
datafusion/core/tests/sql/group_by.rs:
##########
@@ -905,3 +905,220 @@ async fn 
csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> {
     assert_batches_sorted_eq!(expected, &actual);
     Ok(())
 }
+

Review Comment:
   What do you think about writing these tests in `aggregate.slt` instead of `.rs`?
   
   



##########
datafusion/core/tests/sql/group_by.rs:
##########
@@ -905,3 +905,220 @@ async fn 
csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> {
     assert_batches_sorted_eq!(expected, &actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn query_group_by_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, true),
+        Field::new("c2", DataType::Int32, true),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![
+                Some(1),
+                Some(1),
+                Some(2),
+                Some(2),
+                Some(3),
+            ])),
+            Arc::new(Int32Array::from(vec![
+                Some(10),
+                Some(20),
+                Some(10),
+                Some(20),
+                Some(10),
+            ])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+    let sql =
+        "SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) as result FROM test GROUP 
BY c1";
+
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | result |",
+        "+----+--------+",
+        "| 1  | 20     |",
+        "| 2  | 20     |",
+        "| 3  |        |",
+        "+----+--------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+// test avg since it has two state columns
+#[tokio::test]
+async fn query_group_by_avg_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, false),
+        Field::new("c2", DataType::Int32, false),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2])),
+            Arc::new(Int32Array::from(vec![10, 20, 30, 40])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+
+    let sql =
+        "SELECT c1, AVG(c2) FILTER (WHERE c2 >= 20) AS avg_c2 FROM test GROUP 
BY c1";
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | avg_c2 |",
+        "+----+--------+",
+        "| 1  | 20.0   |",
+        "| 2  | 35.0   |",
+        "+----+--------+",
+    ];
+
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn query_group_by_with_multiple_filters() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, false),
+        Field::new("c2", DataType::Int32, false),
+        Field::new("c3", DataType::Int32, false),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2])),
+            Arc::new(Int32Array::from(vec![10, 20, 30, 40])),
+            Arc::new(Int32Array::from(vec![50, 60, 70, 80])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+
+    let sql = "SELECT c1, SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2, AVG(c3) 
FILTER (WHERE c3 <= 70) AS avg_c3 FROM test GROUP BY c1";
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    let expected = vec![
+        "+----+--------+--------+",
+        "| c1 | sum_c2 | avg_c3 |",
+        "+----+--------+--------+",
+        "| 1  | 20     | 55.0   |",
+        "| 2  | 70     | 70.0   |",
+        "+----+--------+--------+",
+    ];
+
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn query_group_by_distinct_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, false),
+        Field::new("c2", DataType::Int32, false),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2, 2])),
+            Arc::new(Int32Array::from(vec![10, 20, 20, 30, 40])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+    let sql = "SELECT c1, COUNT(DISTINCT c2) FILTER (WHERE c2 >= 20) AS 
distinct_c2_count FROM test GROUP BY c1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+-------------------+",
+        "| c1 | distinct_c2_count |",
+        "+----+-------------------+",
+        "| 1  | 1                 |",
+        "| 2  | 3                 |",
+        "+----+-------------------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn query_without_group_by_with_filter() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, false),
+        Field::new("c2", DataType::Int32, false),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2, 2])),
+            Arc::new(Int32Array::from(vec![10, 20, 20, 30, 40])),
+        ],
+    )?;
+
+    let ctx = use_postgres_dialect();
+    ctx.register_batch("test", data)?;
+    let sql = "SELECT SUM(c2) FILTER (WHERE c2 >= 20) AS sum_c2 FROM test";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+--------+",
+        "| sum_c2 |",
+        "+--------+",
+        "| 110    |",
+        "+--------+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+// count is special cased by `aggregate_statistics`

Review Comment:
   Another test case that I think is important is when the filter is on a different column than the aggregate:
   
   ```
   postgres=# create table test as values (1, 10), (2, 20), (3, 30);
   SELECT 3
   postgres=# select * from test;
    column1 | column2
   ---------+---------
          1 |      10
          2 |      20
          3 |      30
   (3 rows)
   
   postgres=# select sum(column1) FILTER (WHERE column2 < 30) from test;
    sum
   -----
      3
   (1 row)
   ```
   
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to