This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 5d08325165 Count wildcard alias (#14927)
5d08325165 is described below

commit 5d08325165c1a7b32e5e35164919e83d46735e98
Author: Jay Zhan <[email protected]>
AuthorDate: Wed Mar 5 09:44:42 2025 +0800

    Count wildcard alias (#14927)
    
    * fix alias
    
    * append the string
    
    * window count
    
    * add column
    
    * apply cargo fmt
    
    * remove TODO
    
    * fix partitioned plan expectations
    
    * fix test
    
    * update doc
    
    * Suggestion to reduce API surface area
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../core/tests/dataframe/dataframe_functions.rs    |   6 +-
 datafusion/core/tests/dataframe/mod.rs             | 306 ++++++++++++++++++---
 datafusion/functions-aggregate/src/count.rs        |  51 +++-
 datafusion/sqllogictest/test_files/subquery.slt    |  38 +++
 4 files changed, 360 insertions(+), 41 deletions(-)

diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs
index 28c0740ca7..fec3ab786f 100644
--- a/datafusion/core/tests/dataframe/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -1145,9 +1145,9 @@ async fn test_count_wildcard() -> Result<()> {
         .build()
         .unwrap();
 
-    let expected = "Sort: count(Int64(1)) ASC NULLS LAST 
[count(Int64(1)):Int64]\
-    \n  Projection: count(Int64(1)) [count(Int64(1)):Int64]\
-    \n    Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1))]] [b:UInt32, 
count(Int64(1)):Int64]\
+    let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\
+    \n  Projection: count(*) [count(*):Int64]\
+    \n    Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] 
[b:UInt32, count(*):Int64]\
     \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
     let formatted_plan = plan.display_indent_schema().to_string();
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 1875180d50..43428d6846 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -32,8 +32,7 @@ use arrow::datatypes::{
 };
 use arrow::error::ArrowError;
 use arrow::util::pretty::pretty_format_batches;
-use datafusion_expr::utils::COUNT_STAR_EXPANSION;
-use datafusion_functions_aggregate::count::{count_all, count_udaf};
+use datafusion_functions_aggregate::count::{count_all, count_all_window};
 use datafusion_functions_aggregate::expr_fn::{
     array_agg, avg, count, count_distinct, max, median, min, sum,
 };
@@ -2455,7 +2454,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
     let ctx = create_join_context()?;
 
     let sql_results = ctx
-        .sql("select b,count(1) from t1 group by b order by count(1)")
+        .sql("select b, count(*) from t1 group by b order by count(*)")
         .await?
         .explain(false, false)?
         .collect()
@@ -2469,9 +2468,52 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
         .explain(false, false)?
         .collect()
         .await?;
-    //make sure sql plan same with df plan
+
+    let expected_sql_result = 
"+---------------+------------------------------------------------------------------------------------------------------------+\
+    \n| plan_type     | plan                                                   
                                                    |\
+    
\n+---------------+------------------------------------------------------------------------------------------------------------+\
+    \n| logical_plan  | Projection: t1.b, count(*)                             
                                                    |\
+    \n|               |   Sort: count(Int64(1)) AS count(*) AS count(*) ASC 
NULLS LAST                                             |\
+    \n|               |     Projection: t1.b, count(Int64(1)) AS count(*), 
count(Int64(1))                                         |\
+    \n|               |       Aggregate: groupBy=[[t1.b]], 
aggr=[[count(Int64(1))]]                                                |\
+    \n|               |         TableScan: t1 projection=[b]                   
                                                    |\
+    \n| physical_plan | ProjectionExec: expr=[b@0 as b, count(*)@1 as 
count(*)]                                                    |\
+    \n|               |   SortPreservingMergeExec: [count(Int64(1))@2 ASC 
NULLS LAST]                                              |\
+    \n|               |     SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], 
preserve_partitioning=[true]                        |\
+    \n|               |       ProjectionExec: expr=[b@0 as b, 
count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] |\
+    \n|               |         AggregateExec: mode=FinalPartitioned, gby=[b@0 
as b], aggr=[count(Int64(1))]                       |\
+    \n|               |           CoalesceBatchesExec: target_batch_size=8192  
                                                    |\
+    \n|               |             RepartitionExec: partitioning=Hash([b@0], 
4), input_partitions=4                               |\
+    \n|               |               RepartitionExec: 
partitioning=RoundRobinBatch(4), input_partitions=1                         |\
+    \n|               |                 AggregateExec: mode=Partial, gby=[b@0 
as b], aggr=[count(Int64(1))]                        |\
+    \n|               |                   DataSourceExec: partitions=1, 
partition_sizes=[1]                                        |\
+    \n|               |                                                        
                                                    |\
+    
\n+---------------+------------------------------------------------------------------------------------------------------------+";
+
     assert_eq!(
-        pretty_format_batches(&sql_results)?.to_string(),
+        expected_sql_result,
+        pretty_format_batches(&sql_results)?.to_string()
+    );
+
+    let expected_df_result = 
"+---------------+--------------------------------------------------------------------------------+\
+\n| plan_type     | plan                                                       
                    |\
+\n+---------------+--------------------------------------------------------------------------------+\
+\n| logical_plan  | Sort: count(*) ASC NULLS LAST                              
                    |\
+\n|               |   Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS 
count(*)]]            |\
+\n|               |     TableScan: t1 projection=[b]                           
                    |\
+\n| physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST]       
                    |\
+\n|               |   SortExec: expr=[count(*)@1 ASC NULLS LAST], 
preserve_partitioning=[true]     |\
+\n|               |     AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], 
aggr=[count(*)]      |\
+\n|               |       CoalesceBatchesExec: target_batch_size=8192          
                    |\
+\n|               |         RepartitionExec: partitioning=Hash([b@0], 4), 
input_partitions=4       |\
+\n|               |           RepartitionExec: 
partitioning=RoundRobinBatch(4), input_partitions=1 |\
+\n|               |             AggregateExec: mode=Partial, gby=[b@0 as b], 
aggr=[count(*)]       |\
+\n|               |               DataSourceExec: partitions=1, 
partition_sizes=[1]                |\
+\n|               |                                                            
                    |\
+\n+---------------+--------------------------------------------------------------------------------+";
+
+    assert_eq!(
+        expected_df_result,
         pretty_format_batches(&df_results)?.to_string()
     );
     Ok(())
@@ -2481,12 +2523,35 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
 async fn test_count_wildcard_on_where_in() -> Result<()> {
     let ctx = create_join_context()?;
     let sql_results = ctx
-        .sql("SELECT a,b FROM t1 WHERE a in (SELECT count(1) FROM t2)")
+        .sql("SELECT a, b FROM t1 WHERE a in (SELECT count(*) FROM t2)")
         .await?
         .explain(false, false)?
         .collect()
         .await?;
 
+    let expected_sql_result = 
"+---------------+------------------------------------------------------------------------------------------------------------------------+\
+\n| plan_type     | plan                                                       
                                                            |\
+\n+---------------+------------------------------------------------------------------------------------------------------------------------+\
+\n| logical_plan  | LeftSemi Join: CAST(t1.a AS Int64) = 
__correlated_sq_1.count(*)                                                      
  |\
+\n|               |   TableScan: t1 projection=[a, b]                          
                                                            |\
+\n|               |   SubqueryAlias: __correlated_sq_1                         
                                                            |\
+\n|               |     Projection: count(Int64(1)) AS count(*)                
                                                            |\
+\n|               |       Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]    
                                                            |\
+\n|               |         TableScan: t2 projection=[]                        
                                                            |\
+\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192                
                                                            |\
+\n|               |   HashJoinExec: mode=Partitioned, join_type=RightSemi, 
on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] |\
+\n|               |     ProjectionExec: expr=[4 as count(*)]                   
                                                            |\
+\n|               |       PlaceholderRowExec                                   
                                                            |\
+\n|               |     ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS 
Int64) as CAST(t1.a AS Int64)]                               |\
+\n|               |       DataSourceExec: partitions=1, partition_sizes=[1]    
                                                            |\
+\n|               |                                                            
                                                            |\
+\n+---------------+------------------------------------------------------------------------------------------------------------------------+";
+
+    assert_eq!(
+        expected_sql_result,
+        pretty_format_batches(&sql_results)?.to_string()
+    );
+
     // In the same SessionContext, AliasGenerator will increase subquery_alias 
id by 1
     // 
https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
     // for compare difference between sql and df logical plan, we need to 
create a new SessionContext here
@@ -2509,9 +2574,26 @@ async fn test_count_wildcard_on_where_in() -> Result<()> {
         .collect()
         .await?;
 
+    let actual_df_result= 
"+---------------+------------------------------------------------------------------------------------------------------------------------+\
+\n| plan_type     | plan                                                       
                                                            |\
+\n+---------------+------------------------------------------------------------------------------------------------------------------------+\
+\n| logical_plan  | LeftSemi Join: CAST(t1.a AS Int64) = 
__correlated_sq_1.count(*)                                                      
  |\
+\n|               |   TableScan: t1 projection=[a, b]                          
                                                            |\
+\n|               |   SubqueryAlias: __correlated_sq_1                         
                                                            |\
+\n|               |     Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS 
count(*)]]                                                      |\
+\n|               |       TableScan: t2 projection=[]                          
                                                            |\
+\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192                
                                                            |\
+\n|               |   HashJoinExec: mode=Partitioned, join_type=RightSemi, 
on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] |\
+\n|               |     ProjectionExec: expr=[4 as count(*)]                   
                                                            |\
+\n|               |       PlaceholderRowExec                                   
                                                            |\
+\n|               |     ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS 
Int64) as CAST(t1.a AS Int64)]                               |\
+\n|               |       DataSourceExec: partitions=1, partition_sizes=[1]    
                                                            |\
+\n|               |                                                            
                                                            |\
+\n+---------------+------------------------------------------------------------------------------------------------------------------------+";
+
     // make sure sql plan same with df plan
     assert_eq!(
-        pretty_format_batches(&sql_results)?.to_string(),
+        actual_df_result,
         pretty_format_batches(&df_results)?.to_string()
     );
 
@@ -2522,11 +2604,34 @@ async fn test_count_wildcard_on_where_in() -> Result<()> {
 async fn test_count_wildcard_on_where_exist() -> Result<()> {
     let ctx = create_join_context()?;
     let sql_results = ctx
-        .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(1) FROM t2)")
+        .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)")
         .await?
         .explain(false, false)?
         .collect()
         .await?;
+
+    let actual_sql_result =
+        
"+---------------+---------------------------------------------------------+\
+    \n| plan_type     | plan                                                   
 |\
+    
\n+---------------+---------------------------------------------------------+\
+    \n| logical_plan  | LeftSemi Join:                                         
 |\
+    \n|               |   TableScan: t1 projection=[a, b]                      
 |\
+    \n|               |   SubqueryAlias: __correlated_sq_1                     
 |\
+    \n|               |     Projection:                                        
 |\
+    \n|               |       Aggregate: groupBy=[[]], 
aggr=[[count(Int64(1))]] |\
+    \n|               |         TableScan: t2 projection=[]                    
 |\
+    \n| physical_plan | NestedLoopJoinExec: join_type=RightSemi                
 |\
+    \n|               |   ProjectionExec: expr=[]                              
 |\
+    \n|               |     PlaceholderRowExec                                 
 |\
+    \n|               |   DataSourceExec: partitions=1, partition_sizes=[1]    
 |\
+    \n|               |                                                        
 |\
+    
\n+---------------+---------------------------------------------------------+";
+
+    assert_eq!(
+        actual_sql_result,
+        pretty_format_batches(&sql_results)?.to_string()
+    );
+
     let df_results = ctx
         .table("t1")
         .await?
@@ -2545,9 +2650,24 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> {
         .collect()
         .await?;
 
-    //make sure sql plan same with df plan
+    let actual_df_result = 
"+---------------+---------------------------------------------------------------------+\
+    \n| plan_type     | plan                                                   
             |\
+    
\n+---------------+---------------------------------------------------------------------+\
+    \n| logical_plan  | LeftSemi Join:                                         
             |\
+    \n|               |   TableScan: t1 projection=[a, b]                      
             |\
+    \n|               |   SubqueryAlias: __correlated_sq_1                     
             |\
+    \n|               |     Projection:                                        
             |\
+    \n|               |       Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) 
AS count(*)]] |\
+    \n|               |         TableScan: t2 projection=[]                    
             |\
+    \n| physical_plan | NestedLoopJoinExec: join_type=RightSemi                
             |\
+    \n|               |   ProjectionExec: expr=[]                              
             |\
+    \n|               |     PlaceholderRowExec                                 
             |\
+    \n|               |   DataSourceExec: partitions=1, partition_sizes=[1]    
             |\
+    \n|               |                                                        
             |\
+    
\n+---------------+---------------------------------------------------------------------+";
+
     assert_eq!(
-        pretty_format_batches(&sql_results)?.to_string(),
+        actual_df_result,
         pretty_format_batches(&df_results)?.to_string()
     );
 
@@ -2559,34 +2679,62 @@ async fn test_count_wildcard_on_window() -> Result<()> {
     let ctx = create_join_context()?;
 
     let sql_results = ctx
-        .sql("select count(1) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING 
AND 2 FOLLOWING)  from t1")
+        .sql("select count(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING 
AND 2 FOLLOWING) from t1")
         .await?
         .explain(false, false)?
         .collect()
         .await?;
+
+    let actual_sql_result = 
"+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 [...]
+\n| plan_type     | plan                                                       
                                                                                
                                                                                
                                                                                
                                                                                
                                                                               
|\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\
+\n| logical_plan  | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS 
FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [t1.a 
DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING                     
                                                                                
                                                                                
                                                                                
        |\
+\n|               |   WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a 
DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]]                   
                                                                                
                                                                                
                                                                                
                                                                                
|\
+\n|               |     TableScan: t1 projection=[a]                           
                                                                                
                                                                                
                                                                                
                                                                                
                                                                               
|\
+\n| physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC 
NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(*) ORDER BY 
[t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]              
                                                                                
                                                                                
                                                                                
  |\
+\n|               |   BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY 
[t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { 
name: \"count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 
PRECEDING AND 2 FOLLOWING\", data_type: Int64, nullable: false, dict_id: 0, 
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, 
start_bound: Preceding(UInt32(6)), end_bound: Following(UInt32(2)), is_causal: 
false }], mode=[Sorted] |\
+\n|               |     SortExec: expr=[a@0 DESC], 
preserve_partitioning=[false]                                                   
                                                                                
                                                                                
                                                                                
                                                                                
                           |\
+\n|               |       DataSourceExec: partitions=1, partition_sizes=[1]    
                                                                                
                                                                                
                                                                                
                                                                                
                                                                               
|\
+\n|               |                                                            
                                                                                
                                                                                
                                                                                
                                                                                
                                                                               
|\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+";
+
+    assert_eq!(
+        actual_sql_result,
+        pretty_format_batches(&sql_results)?.to_string()
+    );
+
     let df_results = ctx
         .table("t1")
         .await?
-        .select(vec![Expr::WindowFunction(WindowFunction::new(
-            WindowFunctionDefinition::AggregateUDF(count_udaf()),
-            vec![Expr::Literal(COUNT_STAR_EXPANSION)],
-        ))
-        .order_by(vec![Sort::new(col("a"), false, true)])
-        .window_frame(WindowFrame::new_bounds(
-            WindowFrameUnits::Range,
-            WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
-            WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
-        ))
-        .build()
-        .unwrap()])?
+        .select(vec![count_all_window()
+            .order_by(vec![Sort::new(col("a"), false, true)])
+            .window_frame(WindowFrame::new_bounds(
+                WindowFrameUnits::Range,
+                WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
+                WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
+            ))
+            .build()
+            .unwrap()])?
         .explain(false, false)?
         .collect()
         .await?;
 
-    //make sure sql plan same with df plan
+    let actual_df_result = 
"+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 [...]
+\n| plan_type     | plan                                                       
                                                                                
                                                                                
                                                                                
                                                                                
                                                                               
|\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\
+\n| logical_plan  | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS 
FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING                                
                                                                                
                                                                                
                                                                                
                                                                                
    |\
+\n|               |   WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a 
DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]]                   
                                                                                
                                                                                
                                                                                
                                                                                
|\
+\n|               |     TableScan: t1 projection=[a]                           
                                                                                
                                                                                
                                                                                
                                                                                
                                                                               
|\
+\n| physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC 
NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(Int64(1)) 
ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]     
                                                                                
                                                                                
                                                                                
    |\
+\n|               |   BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY 
[t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { 
name: \"count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 
PRECEDING AND 2 FOLLOWING\", data_type: Int64, nullable: false, dict_id: 0, 
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, 
start_bound: Preceding(UInt32(6)), end_bound: Following(UInt32(2)), is_causal: 
false }], mode=[Sorted] |\
+\n|               |     SortExec: expr=[a@0 DESC], 
preserve_partitioning=[false]                                                   
                                                                                
                                                                                
                                                                                
                                                                                
                           |\
+\n|               |       DataSourceExec: partitions=1, partition_sizes=[1]    
                                                                                
                                                                                
                                                                                
                                                                                
                                                                               
|\
+\n|               |                                                            
                                                                                
                                                                                
                                                                                
                                                                                
                                                                               
|\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+";
+
     assert_eq!(
-        pretty_format_batches(&df_results)?.to_string(),
-        pretty_format_batches(&sql_results)?.to_string()
+        actual_df_result,
+        pretty_format_batches(&df_results)?.to_string()
     );
 
     Ok(())
@@ -2598,12 +2746,28 @@ async fn test_count_wildcard_on_aggregate() -> 
Result<()> {
     register_alltypes_tiny_pages_parquet(&ctx).await?;
 
     let sql_results = ctx
-        .sql("select count(1) from t1")
+        .sql("select count(*) from t1")
         .await?
         .explain(false, false)?
         .collect()
         .await?;
 
+    let actual_sql_result =
+        
"+---------------+-----------------------------------------------------+\
+\n| plan_type     | plan                                                |\
+\n+---------------+-----------------------------------------------------+\
+\n| logical_plan  | Projection: count(Int64(1)) AS count(*)             |\
+\n|               |   Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\
+\n|               |     TableScan: t1 projection=[]                     |\
+\n| physical_plan | ProjectionExec: expr=[4 as count(*)]                |\
+\n|               |   PlaceholderRowExec                                |\
+\n|               |                                                     |\
+\n+---------------+-----------------------------------------------------+";
+    assert_eq!(
+        actual_sql_result,
+        pretty_format_batches(&sql_results)?.to_string()
+    );
+
     // add `.select(vec![count_wildcard()])?` to make sure we can analyze all 
node instead of just top node.
     let df_results = ctx
         .table("t1")
@@ -2614,9 +2778,17 @@ async fn test_count_wildcard_on_aggregate() -> 
Result<()> {
         .collect()
         .await?;
 
-    //make sure sql plan same with df plan
+    let actual_df_result = 
"+---------------+---------------------------------------------------------------+\
+\n| plan_type     | plan                                                       
   |\
+\n+---------------+---------------------------------------------------------------+\
+\n| logical_plan  | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS 
count(*)]] |\
+\n|               |   TableScan: t1 projection=[]                              
   |\
+\n| physical_plan | ProjectionExec: expr=[4 as count(*)]                       
   |\
+\n|               |   PlaceholderRowExec                                       
   |\
+\n|               |                                                            
   |\
+\n+---------------+---------------------------------------------------------------+";
     assert_eq!(
-        pretty_format_batches(&sql_results)?.to_string(),
+        actual_df_result,
         pretty_format_batches(&df_results)?.to_string()
     );
 
@@ -2628,16 +2800,51 @@ async fn test_count_wildcard_on_where_scalar_subquery() 
-> Result<()> {
     let ctx = create_join_context()?;
 
     let sql_results = ctx
-        .sql("select a,b from t1 where (select count(1) from t2 where t1.a = 
t2.a)>0;")
+        .sql("select a,b from t1 where (select count(*) from t2 where t1.a = 
t2.a)>0;")
         .await?
         .explain(false, false)?
         .collect()
         .await?;
 
+    let actual_sql_result = 
"+---------------+---------------------------------------------------------------------------------------------------------------------------+\
+\n| plan_type     | plan                                                       
                                                               |\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------+\
+\n| logical_plan  | Projection: t1.a, t1.b                                     
                                                               |\
+\n|               |   Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL 
THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0)          |\
+\n|               |     Projection: t1.a, t1.b, __scalar_sq_1.count(*), 
__scalar_sq_1.__always_true                                           |\
+\n|               |       Left Join: t1.a = __scalar_sq_1.a                    
                                                               |\
+\n|               |         TableScan: t1 projection=[a, b]                    
                                                               |\
+\n|               |         SubqueryAlias: __scalar_sq_1                       
                                                               |\
+\n|               |           Projection: count(Int64(1)) AS count(*), t2.a, 
Boolean(true) AS __always_true                                   |\
+\n|               |             Aggregate: groupBy=[[t2.a]], 
aggr=[[count(Int64(1))]]                                                        
 |\
+\n|               |               TableScan: t2 projection=[a]                 
                                                               |\
+\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192                
                                                               |\
+\n|               |   FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 
ELSE count(*)@2 END > 0, projection=[a@0, b@1]                     |\
+\n|               |     CoalesceBatchesExec: target_batch_size=8192            
                                                               |\
+\n|               |       HashJoinExec: mode=Partitioned, join_type=Left, 
on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\
+\n|               |         CoalesceBatchesExec: target_batch_size=8192        
                                                               |\
+\n|               |           RepartitionExec: partitioning=Hash([a@0], 4), 
input_partitions=1                                                |\
+\n|               |             DataSourceExec: partitions=1, 
partition_sizes=[1]                                                             
|\
+\n|               |         ProjectionExec: expr=[count(Int64(1))@1 as 
count(*), a@0 as a, true as __always_true]                             |\
+\n|               |           AggregateExec: mode=FinalPartitioned, gby=[a@0 
as a], aggr=[count(Int64(1))]                                    |\
+\n|               |             CoalesceBatchesExec: target_batch_size=8192    
                                                               |\
+\n|               |               RepartitionExec: partitioning=Hash([a@0], 
4), input_partitions=4                                            |\
+\n|               |                 RepartitionExec: 
partitioning=RoundRobinBatch(4), input_partitions=1                             
         |\
+\n|               |                   AggregateExec: mode=Partial, gby=[a@0 as 
a], aggr=[count(Int64(1))]                                     |\
+\n|               |                     DataSourceExec: partitions=1, 
partition_sizes=[1]                                                     |\
+\n|               |                                                            
                                                               |\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------+";
+    assert_eq!(
+        actual_sql_result,
+        pretty_format_batches(&sql_results)?.to_string()
+    );
+
     // In the same SessionContext, AliasGenerator will increase subquery_alias 
id by 1
     // 
https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
     // for compare difference between sql and df logical plan, we need to 
create a new SessionContext here
     let ctx = create_join_context()?;
+    let agg_expr = count_all();
+    let agg_expr_col = col(agg_expr.schema_name().to_string());
     let df_results = ctx
         .table("t1")
         .await?
@@ -2646,8 +2853,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() 
-> Result<()> {
                 ctx.table("t2")
                     .await?
                     .filter(out_ref_col(DataType::UInt32, 
"t1.a").eq(col("t2.a")))?
-                    .aggregate(vec![], vec![count_all()])?
-                    .select(vec![col(count_all().to_string())])?
+                    .aggregate(vec![], vec![agg_expr])?
+                    .select(vec![agg_expr_col])?
                     .into_unoptimized_plan(),
             ))
             .gt(lit(ScalarValue::UInt8(Some(0)))),
@@ -2657,9 +2864,36 @@ async fn test_count_wildcard_on_where_scalar_subquery() 
-> Result<()> {
         .collect()
         .await?;
 
-    //make sure sql plan same with df plan
+    let actual_df_result = 
"+---------------+---------------------------------------------------------------------------------------------------------------------------+\
+\n| plan_type     | plan                                                       
                                                               |\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------+\
+\n| logical_plan  | Projection: t1.a, t1.b                                     
                                                               |\
+\n|               |   Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL 
THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0)          |\
+\n|               |     Projection: t1.a, t1.b, __scalar_sq_1.count(*), 
__scalar_sq_1.__always_true                                           |\
+\n|               |       Left Join: t1.a = __scalar_sq_1.a                    
                                                               |\
+\n|               |         TableScan: t1 projection=[a, b]                    
                                                               |\
+\n|               |         SubqueryAlias: __scalar_sq_1                       
                                                               |\
+\n|               |           Projection: count(*), t2.a, Boolean(true) AS 
__always_true                                                      |\
+\n|               |             Aggregate: groupBy=[[t2.a]], 
aggr=[[count(Int64(1)) AS count(*)]]                                            
 |\
+\n|               |               TableScan: t2 projection=[a]                 
                                                               |\
+\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192                
                                                               |\
+\n|               |   FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 
ELSE count(*)@2 END > 0, projection=[a@0, b@1]                     |\
+\n|               |     CoalesceBatchesExec: target_batch_size=8192            
                                                               |\
+\n|               |       HashJoinExec: mode=Partitioned, join_type=Left, 
on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\
+\n|               |         CoalesceBatchesExec: target_batch_size=8192        
                                                               |\
+\n|               |           RepartitionExec: partitioning=Hash([a@0], 4), 
input_partitions=1                                                |\
+\n|               |             DataSourceExec: partitions=1, 
partition_sizes=[1]                                                             
|\
+\n|               |         ProjectionExec: expr=[count(*)@1 as count(*), a@0 
as a, true as __always_true]                                    |\
+\n|               |           AggregateExec: mode=FinalPartitioned, gby=[a@0 
as a], aggr=[count(*)]                                           |\
+\n|               |             CoalesceBatchesExec: target_batch_size=8192    
                                                               |\
+\n|               |               RepartitionExec: partitioning=Hash([a@0], 
4), input_partitions=4                                            |\
+\n|               |                 RepartitionExec: 
partitioning=RoundRobinBatch(4), input_partitions=1                             
         |\
+\n|               |                   AggregateExec: mode=Partial, gby=[a@0 as 
a], aggr=[count(*)]                                            |\
+\n|               |                     DataSourceExec: partitions=1, 
partition_sizes=[1]                                                     |\
+\n|               |                                                            
                                                               |\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------+";
     assert_eq!(
-        pretty_format_batches(&sql_results)?.to_string(),
+        actual_df_result,
         pretty_format_batches(&df_results)?.to_string()
     );
 
@@ -4228,7 +4462,9 @@ fn create_join_context() -> Result<SessionContext> {
         ],
     )?;
 
-    let ctx = SessionContext::new();
+    let config = SessionConfig::new().with_target_partitions(4);
+    let ctx = SessionContext::new_with_config(config);
+    // let ctx = SessionContext::new();
 
     ctx.register_batch("t1", batch1)?;
     ctx.register_batch("t2", batch2)?;
diff --git a/datafusion/functions-aggregate/src/count.rs 
b/datafusion/functions-aggregate/src/count.rs
index a3339f0fce..2d995b4a41 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -17,6 +17,7 @@
 
 use ahash::RandomState;
 use datafusion_common::stats::Precision;
+use datafusion_expr::expr::WindowFunction;
 use 
datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
 use datafusion_macros::user_doc;
 use datafusion_physical_expr::expressions;
@@ -51,7 +52,9 @@ use datafusion_expr::{
     function::AccumulatorArgs, utils::format_state_name, Accumulator, 
AggregateUDFImpl,
     Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, 
Volatility,
 };
-use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
+use datafusion_expr::{
+    Expr, ReversedUDAF, StatisticsArgs, TypeSignature, 
WindowFunctionDefinition,
+};
 use datafusion_functions_aggregate_common::aggregate::count_distinct::{
     BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
     PrimitiveDistinctCountAccumulator,
@@ -79,9 +82,51 @@ pub fn count_distinct(expr: Expr) -> Expr {
     ))
 }
 
-/// Creates aggregation to count all rows, equivalent to `COUNT(*)`, 
`COUNT()`, `COUNT(1)`
+/// Creates aggregation to count all rows.
+///
+/// In SQL this is `SELECT COUNT(*) ... `
+///
+/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`, and is
+/// aliased to a column named `"count(*)"` for backward compatibility.
+///
+/// Example
+/// ```
+/// # use datafusion_functions_aggregate::count::count_all;
+/// # use datafusion_expr::col;
+/// // create `count(*)` expression
+/// let expr = count_all();
+/// assert_eq!(expr.schema_name().to_string(), "count(*)");
+/// // if you need to refer to this column, use the `schema_name` function
+/// let expr = col(expr.schema_name().to_string());
+/// ```
 pub fn count_all() -> Expr {
-    count(Expr::Literal(COUNT_STAR_EXPANSION))
+    count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)")
+}
+
+/// Creates window aggregation to count all rows.
+///
+/// In SQL this is `SELECT COUNT(*) OVER (..) ... `
+///
+/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
+///
+/// Example
+/// ```
+/// # use datafusion_functions_aggregate::count::count_all_window;
+/// # use datafusion_expr::col;
+/// // create `count(*)` OVER ... window function expression
+/// let expr = count_all_window();
+/// assert_eq!(
+///   expr.schema_name().to_string(),
+///   "count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED 
FOLLOWING"
+/// );
+/// // if you need to refer to this column, use the `schema_name` function
+/// let expr = col(expr.schema_name().to_string());
+/// ```
+pub fn count_all_window() -> Expr {
+    Expr::WindowFunction(WindowFunction::new(
+        WindowFunctionDefinition::AggregateUDF(count_udaf()),
+        vec![Expr::Literal(COUNT_STAR_EXPANSION)],
+    ))
 }
 
 #[user_doc(
diff --git a/datafusion/sqllogictest/test_files/subquery.slt 
b/datafusion/sqllogictest/test_files/subquery.slt
index 94c9eaf810..207bb72fd5 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -1393,3 +1393,41 @@ item1 1970-01-01T00:00:03 75
 
 statement ok
 drop table source_table;
+
+statement count 0
+drop table t1;
+
+statement count 0
+drop table t2;
+
+statement count 0
+drop table t3;
+
+# test count wildcard
+statement count 0
+create table t1(a int) as values (1);
+
+statement count 0
+create table t2(b int) as values (1);
+
+query I
+SELECT a FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)
+----
+1
+
+query TT
+explain SELECT a FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)
+----
+logical_plan
+01)LeftSemi Join: 
+02)--TableScan: t1 projection=[a]
+03)--SubqueryAlias: __correlated_sq_1
+04)----Projection: 
+05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
+06)--------TableScan: t2 projection=[]
+
+statement count 0
+drop table t1;
+
+statement count 0
+drop table t2;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to