From 842bba651d6fd184402bcc404aad6a709c988614 Mon Sep 17 00:00:00 2001
From: David Christensen <david.christensen@crunchydata.com>
Date: Mon, 11 Mar 2024 10:37:05 -0400
Subject: [PATCH v4] Add GROUP BY ALL

GROUP BY ALL is a form of GROUP BY which adds any TargetExpr that does not
contain an Aggref into the groupClause of the query, effectively making it
exactly equivalent to specifying those same expressions in an explicit GROUP BY list.

Since this is exclusive with any other GROUP BY form, this is fairly simple to
add into the grammar and handle without needing to get into grouping sets or
other more complicated forms.

This greatly improves data exploration in particular, as well as making it so
you don't need to trivially wrap more complicated queries in a subquery or
reproduce long, complicated expressions in the literal GROUP BY.
---
 src/backend/parser/analyze.c             |   8 +-
 src/backend/parser/gram.y                |  16 ++-
 src/backend/parser/parse_clause.c        |  41 ++++++-
 src/backend/utils/adt/ruleutils.c        |   5 +-
 src/include/nodes/parsenodes.h           |   3 +
 src/include/parser/parse_clause.h        |   2 +-
 src/test/regress/expected/aggregates.out | 148 +++++++++++++++++++++++
 src/test/regress/sql/aggregates.sql      |  62 ++++++++++
 8 files changed, 278 insertions(+), 7 deletions(-)

diff --git a/src/backend/parser/analyze.c b/src/backend/parser/analyze.c
index b9763ea1714..991a5919ab1 100644
--- a/src/backend/parser/analyze.c
+++ b/src/backend/parser/analyze.c
@@ -48,6 +48,7 @@
 #include "parser/parse_target.h"
 #include "parser/parse_type.h"
 #include "parser/parsetree.h"
+#include "rewrite/rewriteManip.h"
 #include "utils/backend_status.h"
 #include "utils/builtins.h"
 #include "utils/guc.h"
@@ -1444,8 +1445,10 @@ transformSelectStmt(ParseState *pstate, SelectStmt *stmt)
 											&qry->targetList,
 											qry->sortClause,
 											EXPR_KIND_GROUP_BY,
-											false /* allow SQL92 rules */ );
+											false /* allow SQL92 rules */,
+											stmt->groupAll);
 	qry->groupDistinct = stmt->groupDistinct;
+	qry->groupAll = stmt->groupAll;
 
 	if (stmt->distinctClause == NIL)
 	{
@@ -2934,7 +2937,8 @@ transformPLAssignStmt(ParseState *pstate, PLAssignStmt *stmt)
 											&qry->targetList,
 											qry->sortClause,
 											EXPR_KIND_GROUP_BY,
-											false /* allow SQL92 rules */ );
+											false /* allow SQL92 rules */,
+											qry->groupAll);
 
 	if (sstmt->distinctClause == NIL)
 	{
diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y
index 9fd48acb1f8..3d97ceddc1e 100644
--- a/src/backend/parser/gram.y
+++ b/src/backend/parser/gram.y
@@ -120,6 +120,7 @@ typedef struct SelectLimit
 typedef struct GroupClause
 {
 	bool		distinct;
+	bool		all;
 	List	   *list;
 } GroupClause;
 
@@ -12993,6 +12994,7 @@ simple_select:
 					n->whereClause = $6;
 					n->groupClause = ($7)->list;
 					n->groupDistinct = ($7)->distinct;
+					n->groupAll = ($7)->all;
 					n->havingClause = $8;
 					n->windowClause = $9;
 					$$ = (Node *) n;
@@ -13010,6 +13012,7 @@ simple_select:
 					n->whereClause = $6;
 					n->groupClause = ($7)->list;
 					n->groupDistinct = ($7)->distinct;
+					n->groupAll = ($7)->all;
 					n->havingClause = $8;
 					n->windowClause = $9;
 					$$ = (Node *) n;
@@ -13502,12 +13505,21 @@ first_or_next: FIRST_P								{ $$ = 0; }
  * GroupingSet node of some type.
  */
 group_clause:
-			GROUP_P BY set_quantifier group_by_list
+			GROUP_P BY ALL
+				{
+					GroupClause *n = (GroupClause *) palloc(sizeof(GroupClause));
+					n->distinct = false;
+					n->list = NIL;
+					n->all = true;
+					$$ = n;
+				}
+			| GROUP_P BY set_quantifier group_by_list
 				{
 					GroupClause *n = (GroupClause *) palloc(sizeof(GroupClause));
 
 					n->distinct = $3 == SET_QUANTIFIER_DISTINCT;
 					n->list = $4;
+					n->all = false;
 					$$ = n;
 				}
 			| /*EMPTY*/
@@ -13516,6 +13528,7 @@ group_clause:
 
 					n->distinct = false;
 					n->list = NIL;
+					n->all = false;
 					$$ = n;
 				}
 		;
@@ -17618,6 +17631,7 @@ PLpgSQL_Expr: opt_distinct_clause opt_target_list
 					n->whereClause = $4;
 					n->groupClause = ($5)->list;
 					n->groupDistinct = ($5)->distinct;
+					n->groupAll = ($5)->all;
 					n->havingClause = $6;
 					n->windowClause = $7;
 					n->sortClause = $8;
diff --git a/src/backend/parser/parse_clause.c b/src/backend/parser/parse_clause.c
index 9f20a70ce13..7371a98e1f3 100644
--- a/src/backend/parser/parse_clause.c
+++ b/src/backend/parser/parse_clause.c
@@ -2598,6 +2598,9 @@ transformGroupingSet(List **flatresult,
  * GROUP BY items will be added to the targetlist (as resjunk columns)
  * if not already present, so the targetlist must be passed by reference.
  *
+ * If GROUP BY ALL is specified, the groupClause will be inferred to be all
+ * non-aggregate expressions in the targetlist.
+ *
  * This is also used for window PARTITION BY clauses (which act almost the
  * same, but are always interpreted per SQL99 rules).
  *
@@ -2627,11 +2630,12 @@ transformGroupingSet(List **flatresult,
  * sortClause	ORDER BY clause (SortGroupClause nodes)
  * exprKind		expression kind
  * useSQL99		SQL99 rather than SQL92 syntax
+ * groupByAll	is this a GROUP BY ALL statement?
  */
 List *
 transformGroupClause(ParseState *pstate, List *grouplist, List **groupingSets,
 					 List **targetlist, List *sortClause,
-					 ParseExprKind exprKind, bool useSQL99)
+					 ParseExprKind exprKind, bool useSQL99, bool groupByAll)
 {
 	List	   *result = NIL;
 	List	   *flat_grouplist;
@@ -2640,6 +2644,38 @@ transformGroupClause(ParseState *pstate, List *grouplist, List **groupingSets,
 	bool		hasGroupingSets = false;
 	Bitmapset  *seen_local = NULL;
 
+
+	/*
+	 * If we have GROUP BY ALL, we cannot (by definition) have other GROUP BY
+	 * options, including grouping sets.
+	 */
+
+	if (groupByAll)
+	{
+		Assert(grouplist == NULL);
+
+		/*
+		 * Iterate over targets, any non-aggregate gets added as a Target.
+		 * Note that it's not enough to check for a top-level Aggref; we need
+		 * to ensure that any sub-expression here does not include an Aggref
+		 * (for instance an expression such as `sum(col) + 4` should not be
+		 * added as a grouping target).
+		 *
+		 * We also need to skip any resjunk columns, since we only want to
+		 * group by real TLEs.
+		 */
+		foreach_ptr (TargetEntry, tle, *targetlist)
+		{
+			if (tle->resjunk)
+				continue;
+
+			if (!contain_aggs_of_level((Node *)tle->expr, 0))
+				result = addTargetToGroupList(pstate, tle, result, *targetlist, exprLocation((Node *)tle));
+		}
+
+		return result;
+	}
+
 	/*
 	 * Recursively flatten implicit RowExprs. (Technically this is only needed
 	 * for GROUP BY, per the syntax rules for grouping sets, but we do it
@@ -2822,7 +2858,8 @@ transformWindowDefinitions(ParseState *pstate,
 											   targetlist,
 											   orderClause,
 											   EXPR_KIND_WINDOW_PARTITION,
-											   true /* force SQL99 rules */ );
+											   true /* force SQL99 rules */,
+											   false /* group by all */ );
 
 		/*
 		 * And prepare the new WindowClause.
diff --git a/src/backend/utils/adt/ruleutils.c b/src/backend/utils/adt/ruleutils.c
index defcdaa8b34..aa97cc11c78 100644
--- a/src/backend/utils/adt/ruleutils.c
+++ b/src/backend/utils/adt/ruleutils.c
@@ -6186,7 +6186,10 @@ get_basic_select_query(Query *query, deparse_context *context)
 		save_ingroupby = context->inGroupBy;
 		context->inGroupBy = true;
 
-		if (query->groupingSets == NIL)
+		if (query->groupAll)
+			appendContextKeyword(context, " ALL ",
+								 -PRETTYINDENT_STD, PRETTYINDENT_STD, 1);
+		else if (query->groupingSets == NIL)
 		{
 			sep = "";
 			foreach(l, query->groupClause)
diff --git a/src/include/nodes/parsenodes.h b/src/include/nodes/parsenodes.h
index f1706df58fd..fa38d303d2c 100644
--- a/src/include/nodes/parsenodes.h
+++ b/src/include/nodes/parsenodes.h
@@ -215,6 +215,8 @@ typedef struct Query
 
 	List	   *groupClause;	/* a list of SortGroupClause's */
 	bool		groupDistinct;	/* is the group by clause distinct? */
+	bool		groupAll;		/* is the group by for all non-aggregate
+							 	 * columns? */
 
 	List	   *groupingSets;	/* a list of GroupingSet's if present */
 
@@ -2192,6 +2194,7 @@ typedef struct SelectStmt
 	Node	   *whereClause;	/* WHERE qualification */
 	List	   *groupClause;	/* GROUP BY clauses */
 	bool		groupDistinct;	/* Is this GROUP BY DISTINCT? */
+	bool		groupAll;		/* Is this GROUP BY ALL? */
 	Node	   *havingClause;	/* HAVING conditional-expression */
 	List	   *windowClause;	/* WINDOW window_name AS (...), ... */
 
diff --git a/src/include/parser/parse_clause.h b/src/include/parser/parse_clause.h
index 3e9894926de..aa8b54ab5c4 100644
--- a/src/include/parser/parse_clause.h
+++ b/src/include/parser/parse_clause.h
@@ -28,7 +28,7 @@ extern Node *transformLimitClause(ParseState *pstate, Node *clause,
 extern List *transformGroupClause(ParseState *pstate, List *grouplist,
 								  List **groupingSets,
 								  List **targetlist, List *sortClause,
-								  ParseExprKind exprKind, bool useSQL99);
+								  ParseExprKind exprKind, bool useSQL99, bool groupByAll);
 extern List *transformSortClause(ParseState *pstate, List *orderlist,
 								 List **targetlist, ParseExprKind exprKind,
 								 bool useSQL99);
diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out
index 1f24f6ffd1f..76e7728c678 100644
--- a/src/test/regress/expected/aggregates.out
+++ b/src/test/regress/expected/aggregates.out
@@ -1557,6 +1557,154 @@ drop table t2;
 drop table t3;
 drop table p_t1;
 --
+-- Test GROUP BY ALL
+--
+-- We don't care about the data here, just the proper transformation of the
+-- results, so test some queries and verify the EXPLAIN plans.
+CREATE TEMP TABLE t1 (
+  a int,
+  b int,
+  c int
+);
+-- basic field check
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT b, COUNT(*) FROM t1 GROUP BY ALL;
+      QUERY PLAN      
+----------------------
+ HashAggregate
+   Group Key: b
+   ->  Seq Scan on t1
+(3 rows)
+
+-- throw a null in the values too
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, COUNT(a) FROM t1 GROUP BY ALL;
+      QUERY PLAN      
+----------------------
+ HashAggregate
+   Group Key: a
+   ->  Seq Scan on t1
+(3 rows)
+
+-- multiple columns, non-consecutive order
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, SUM(b), b FROM t1 GROUP BY ALL;
+      QUERY PLAN      
+----------------------
+ HashAggregate
+   Group Key: a, b
+   ->  Seq Scan on t1
+(3 rows)
+
+-- multi columns, no aggregate
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a + b FROM t1 GROUP BY ALL;
+      QUERY PLAN      
+----------------------
+ HashAggregate
+   Group Key: (a + b)
+   ->  Seq Scan on t1
+(3 rows)
+
+-- non-top-level expression
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, SUM(b) + 4 FROM t1 GROUP BY ALL;
+      QUERY PLAN      
+----------------------
+ HashAggregate
+   Group Key: a
+   ->  Seq Scan on t1
+(3 rows)
+
+-- including grouped column
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, SUM(b) + a FROM t1 GROUP BY ALL;
+      QUERY PLAN      
+----------------------
+ HashAggregate
+   Group Key: a
+   ->  Seq Scan on t1
+(3 rows)
+
+-- oops all aggregates
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT COUNT(a), SUM(b) FROM t1 GROUP BY ALL;
+      QUERY PLAN      
+----------------------
+ Aggregate
+   ->  Seq Scan on t1
+(2 rows)
+
+-- empty column list
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT FROM t1 GROUP BY ALL;
+   QUERY PLAN   
+----------------
+ Seq Scan on t1
+(1 row)
+
+-- filter
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, COUNT(a) FILTER(WHERE b = 2) FROM t1 GROUP BY ALL;
+      QUERY PLAN      
+----------------------
+ HashAggregate
+   Group Key: a
+   ->  Seq Scan on t1
+(3 rows)
+
+-- all cols
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT * FROM t1 GROUP BY ALL;
+      QUERY PLAN      
+----------------------
+ HashAggregate
+   Group Key: a, b, c
+   ->  Seq Scan on t1
+(3 rows)
+
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT *, count(*) FROM t1 GROUP BY ALL;
+      QUERY PLAN      
+----------------------
+ HashAggregate
+   Group Key: a, b, c
+   ->  Seq Scan on t1
+(3 rows)
+
+-- expression without including aggregate columns
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, SUM(b) + c FROM t1 GROUP BY ALL;
+ERROR:  column "t1.c" must appear in the GROUP BY clause or be used in an aggregate function
+LINE 1: ...OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, SUM(b) + c FROM t1 ...
+                                                             ^
+-- subquery
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT (SELECT a FROM t1 LIMIT 1), SUM(b) FROM t1 GROUP BY ALL;
+            QUERY PLAN             
+-----------------------------------
+ GroupAggregate
+   InitPlan 1
+     ->  Limit
+           ->  Seq Scan on t1 t1_1
+   ->  Seq Scan on t1
+(5 rows)
+
+-- cte
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) WITH a_query AS (SELECT a, SUM(b) AS sum_b FROM t1 GROUP BY ALL)
+SELECT AVG(a), sum_b  FROM a_query GROUP BY ALL;
+         QUERY PLAN         
+----------------------------
+ HashAggregate
+   Group Key: sum(t1.b)
+   ->  HashAggregate
+         Group Key: t1.a
+         ->  Seq Scan on t1
+(5 rows)
+
+-- verify deparse
+CREATE VIEW v1 AS SELECT b, COUNT(*) FROM t1 GROUP BY ALL;
+NOTICE:  view "v1" will be a temporary view
+SELECT pg_get_viewdef('v1'::regclass);
+    pg_get_viewdef     
+-----------------------
+  SELECT b,           +
+     count(*) AS count+
+    FROM t1           +
+   GROUP BY           +
+   ALL ;
+(1 row)
+
+DROP VIEW v1;
+DROP TABLE t1;
+--
 -- Test GROUP BY matching of join columns that are type-coerced due to USING
 --
 create temp table t1(f1 int, f2 int);
diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql
index 62540b1ffa4..10c6d4dd1a4 100644
--- a/src/test/regress/sql/aggregates.sql
+++ b/src/test/regress/sql/aggregates.sql
@@ -549,6 +549,68 @@ drop table t2;
 drop table t3;
 drop table p_t1;
 
+
+--
+-- Test GROUP BY ALL
+--
+
+-- We don't care about the data here, just the proper transformation of the
+-- results, so test some queries and verify the EXPLAIN plans.
+
+CREATE TEMP TABLE t1 (
+  a int,
+  b int,
+  c int
+);
+
+-- basic field check
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT b, COUNT(*) FROM t1 GROUP BY ALL;
+
+-- throw a null in the values too
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, COUNT(a) FROM t1 GROUP BY ALL;
+
+-- multiple columns, non-consecutive order
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, SUM(b), b FROM t1 GROUP BY ALL;
+
+-- multi columns, no aggregate
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a + b FROM t1 GROUP BY ALL;
+
+-- non-top-level expression
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, SUM(b) + 4 FROM t1 GROUP BY ALL;
+
+-- including grouped column
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, SUM(b) + a FROM t1 GROUP BY ALL;
+
+-- oops all aggregates
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT COUNT(a), SUM(b) FROM t1 GROUP BY ALL;
+
+-- empty column list
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT FROM t1 GROUP BY ALL;
+
+-- filter
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, COUNT(a) FILTER(WHERE b = 2) FROM t1 GROUP BY ALL;
+
+-- all cols
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT * FROM t1 GROUP BY ALL;
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT *, count(*) FROM t1 GROUP BY ALL;
+
+-- expression without including aggregate columns
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT a, SUM(b) + c FROM t1 GROUP BY ALL;
+
+-- subquery
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) SELECT (SELECT a FROM t1 LIMIT 1), SUM(b) FROM t1 GROUP BY ALL;
+
+-- cte
+EXPLAIN (COSTS OFF, TIMING OFF, SUMMARY OFF, BUFFERS OFF) WITH a_query AS (SELECT a, SUM(b) AS sum_b FROM t1 GROUP BY ALL)
+SELECT AVG(a), sum_b  FROM a_query GROUP BY ALL;
+
+-- verify deparse
+CREATE VIEW v1 AS SELECT b, COUNT(*) FROM t1 GROUP BY ALL;
+SELECT pg_get_viewdef('v1'::regclass);
+
+DROP VIEW v1;
+DROP TABLE t1;
+
 --
 -- Test GROUP BY matching of join columns that are type-coerced due to USING
 --
-- 
2.49.0

