From 2d99e7e2e604c01ec5bb306888daa2562a4dbdfd Mon Sep 17 00:00:00 2001
From: Richard Guo <guofenglinux@gmail.com>
Date: Thu, 16 May 2024 06:17:37 +0000
Subject: [PATCH v6 1/2] Introduce a RTE for the grouping step

---
 .../postgres_fdw/expected/postgres_fdw.out    |   2 +-
 src/backend/commands/explain.c                |  24 ++-
 src/backend/nodes/nodeFuncs.c                 |  14 ++
 src/backend/nodes/outfuncs.c                  |   3 +
 src/backend/nodes/print.c                     |   4 +
 src/backend/nodes/readfuncs.c                 |   3 +
 src/backend/optimizer/path/allpaths.c         |   4 +
 src/backend/optimizer/path/equivclass.c       |  12 ++
 src/backend/optimizer/plan/initsplan.c        |   4 +
 src/backend/optimizer/plan/planner.c          |  32 ++-
 src/backend/optimizer/plan/setrefs.c          |   1 +
 src/backend/optimizer/prep/prepjointree.c     |   9 +-
 src/backend/optimizer/util/var.c              | 147 +++++++++++++
 src/backend/parser/parse_agg.c                | 201 ++++++++++++------
 src/backend/parser/parse_relation.c           |  79 ++++++-
 src/backend/parser/parse_target.c             |   2 +
 src/backend/utils/adt/ruleutils.c             |  20 +-
 src/include/commands/explain.h                |   1 +
 src/include/nodes/nodeFuncs.h                 |   2 +
 src/include/nodes/parsenodes.h                |   9 +
 src/include/nodes/pathnodes.h                 |   5 +
 src/include/optimizer/optimizer.h             |   1 +
 src/include/parser/parse_node.h               |   2 +
 src/include/parser/parse_relation.h           |   2 +
 src/test/regress/expected/groupingsets.out    |  49 +++++
 src/test/regress/sql/groupingsets.sql         |  23 ++
 26 files changed, 576 insertions(+), 79 deletions(-)

diff --git a/contrib/postgres_fdw/expected/postgres_fdw.out b/contrib/postgres_fdw/expected/postgres_fdw.out
index 078b8a966f..edc8f1d51b 100644
--- a/contrib/postgres_fdw/expected/postgres_fdw.out
+++ b/contrib/postgres_fdw/expected/postgres_fdw.out
@@ -3669,7 +3669,7 @@ select count(*), sum(t1.c1), avg(t2.c1) from (select c1 from ft4 where c1 betwee
  Foreign Scan
    Output: (count(*)), (sum(ft4.c1)), (avg(ft5.c1))
    Relations: Aggregate on ((public.ft4) FULL JOIN (public.ft5))
-   Remote SQL: SELECT count(*), sum(s4.c1), avg(s5.c1) FROM ((SELECT c1 FROM "S 1"."T 3" WHERE ((c1 >= 50)) AND ((c1 <= 60))) s4(c1) FULL JOIN (SELECT c1 FROM "S 1"."T 4" WHERE ((c1 >= 50)) AND ((c1 <= 60))) s5(c1) ON (((s4.c1 = s5.c1))))
+   Remote SQL: SELECT count(*), sum(s5.c1), avg(s6.c1) FROM ((SELECT c1 FROM "S 1"."T 3" WHERE ((c1 >= 50)) AND ((c1 <= 60))) s5(c1) FULL JOIN (SELECT c1 FROM "S 1"."T 4" WHERE ((c1 >= 50)) AND ((c1 <= 60))) s6(c1) ON (((s5.c1 = s6.c1))))
 (4 rows)
 
 select count(*), sum(t1.c1), avg(t2.c1) from (select c1 from ft4 where c1 between 50 and 60) t1 full join (select c1 from ft5 where c1 between 50 and 60) t2 on (t1.c1 = t2.c1);
diff --git a/src/backend/commands/explain.c b/src/backend/commands/explain.c
index 94511a5a02..57a63cb92e 100644
--- a/src/backend/commands/explain.c
+++ b/src/backend/commands/explain.c
@@ -877,6 +877,7 @@ ExplainPrintPlan(ExplainState *es, QueryDesc *queryDesc)
 {
 	Bitmapset  *rels_used = NULL;
 	PlanState  *ps;
+	ListCell   *lc;
 
 	/* Set up ExplainState fields associated with this plan tree */
 	Assert(queryDesc->plannedstmt != NULL);
@@ -887,6 +888,17 @@ ExplainPrintPlan(ExplainState *es, QueryDesc *queryDesc)
 	es->deparse_cxt = deparse_context_for_plan_tree(queryDesc->plannedstmt,
 													es->rtable_names);
 	es->printed_subplans = NULL;
+	es->rtable_size = list_length(es->rtable);
+	foreach (lc, es->rtable)
+	{
+		RangeTblEntry *rte = lfirst_node(RangeTblEntry, lc);
+
+		if (rte->rtekind == RTE_GROUP)
+		{
+			es->rtable_size--;
+			break;
+		}
+	}
 
 	/*
 	 * Sometimes we mark a Gather node as "invisible", which means that it's
@@ -2463,7 +2475,7 @@ show_plan_tlist(PlanState *planstate, List *ancestors, ExplainState *es)
 	context = set_deparse_context_plan(es->deparse_cxt,
 									   plan,
 									   ancestors);
-	useprefix = list_length(es->rtable) > 1;
+	useprefix = es->rtable_size > 1;
 
 	/* Deparse each result column (we now include resjunk ones) */
 	foreach(lc, plan->targetlist)
@@ -2547,7 +2559,7 @@ show_upper_qual(List *qual, const char *qlabel,
 {
 	bool		useprefix;
 
-	useprefix = (list_length(es->rtable) > 1 || es->verbose);
+	useprefix = (es->rtable_size > 1 || es->verbose);
 	show_qual(qual, qlabel, planstate, ancestors, useprefix, es);
 }
 
@@ -2637,7 +2649,7 @@ show_grouping_sets(PlanState *planstate, Agg *agg,
 	context = set_deparse_context_plan(es->deparse_cxt,
 									   planstate->plan,
 									   ancestors);
-	useprefix = (list_length(es->rtable) > 1 || es->verbose);
+	useprefix = (es->rtable_size > 1 || es->verbose);
 
 	ExplainOpenGroup("Grouping Sets", "Grouping Sets", false, es);
 
@@ -2777,7 +2789,7 @@ show_sort_group_keys(PlanState *planstate, const char *qlabel,
 	context = set_deparse_context_plan(es->deparse_cxt,
 									   plan,
 									   ancestors);
-	useprefix = (list_length(es->rtable) > 1 || es->verbose);
+	useprefix = (es->rtable_size > 1 || es->verbose);
 
 	for (keyno = 0; keyno < nkeys; keyno++)
 	{
@@ -2889,7 +2901,7 @@ show_tablesample(TableSampleClause *tsc, PlanState *planstate,
 	context = set_deparse_context_plan(es->deparse_cxt,
 									   planstate->plan,
 									   ancestors);
-	useprefix = list_length(es->rtable) > 1;
+	useprefix = es->rtable_size > 1;
 
 	/* Get the tablesample method name */
 	method_name = get_func_name(tsc->tsmhandler);
@@ -3339,7 +3351,7 @@ show_memoize_info(MemoizeState *mstate, List *ancestors, ExplainState *es)
 	 * It's hard to imagine having a memoize node with fewer than 2 RTEs, but
 	 * let's just keep the same useprefix logic as elsewhere in this file.
 	 */
-	useprefix = list_length(es->rtable) > 1 || es->verbose;
+	useprefix = es->rtable_size > 1 || es->verbose;
 
 	/* Set up deparsing context */
 	context = set_deparse_context_plan(es->deparse_cxt,
diff --git a/src/backend/nodes/nodeFuncs.c b/src/backend/nodes/nodeFuncs.c
index 89ee4b61f2..6f0f8e8c54 100644
--- a/src/backend/nodes/nodeFuncs.c
+++ b/src/backend/nodes/nodeFuncs.c
@@ -2862,6 +2862,11 @@ range_table_entry_walker_impl(RangeTblEntry *rte,
 		case RTE_RESULT:
 			/* nothing to do */
 			break;
+		case RTE_GROUP:
+			if (!(flags & QTW_IGNORE_GROUPEXPRS))
+				if (WALK(rte->groupexprs))
+					return true;
+			break;
 	}
 
 	if (WALK(rte->securityQuals))
@@ -3900,6 +3905,15 @@ range_table_mutator_impl(List *rtable,
 			case RTE_RESULT:
 				/* nothing to do */
 				break;
+			case RTE_GROUP:
+				if (!(flags & QTW_IGNORE_GROUPEXPRS))
+					MUTATE(newrte->groupexprs, rte->groupexprs, List *);
+				else
+				{
+					/* else, copy group exprs as-is */
+					newrte->groupexprs = copyObject(rte->groupexprs);
+				}
+				break;
 		}
 		MUTATE(newrte->securityQuals, rte->securityQuals, List *);
 		newrt = lappend(newrt, newrte);
diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c
index 3337b77ae6..9827cf16be 100644
--- a/src/backend/nodes/outfuncs.c
+++ b/src/backend/nodes/outfuncs.c
@@ -562,6 +562,9 @@ _outRangeTblEntry(StringInfo str, const RangeTblEntry *node)
 		case RTE_RESULT:
 			/* no extra fields */
 			break;
+		case RTE_GROUP:
+			WRITE_NODE_FIELD(groupexprs);
+			break;
 		default:
 			elog(ERROR, "unrecognized RTE kind: %d", (int) node->rtekind);
 			break;
diff --git a/src/backend/nodes/print.c b/src/backend/nodes/print.c
index 02798f4482..03416e8f4a 100644
--- a/src/backend/nodes/print.c
+++ b/src/backend/nodes/print.c
@@ -300,6 +300,10 @@ print_rt(const List *rtable)
 				printf("%d\t%s\t[result]",
 					   i, rte->eref->aliasname);
 				break;
+			case RTE_GROUP:
+				printf("%d\t%s\t[group]",
+					   i, rte->eref->aliasname);
+				break;
 			default:
 				printf("%d\t%s\t[unknown rtekind]",
 					   i, rte->eref->aliasname);
diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c
index c4d01a441a..818e472a3b 100644
--- a/src/backend/nodes/readfuncs.c
+++ b/src/backend/nodes/readfuncs.c
@@ -422,6 +422,9 @@ _readRangeTblEntry(void)
 		case RTE_RESULT:
 			/* no extra fields */
 			break;
+		case RTE_GROUP:
+			READ_NODE_FIELD(groupexprs);
+			break;
 		default:
 			elog(ERROR, "unrecognized RTE kind: %d",
 				 (int) local_node->rtekind);
diff --git a/src/backend/optimizer/path/allpaths.c b/src/backend/optimizer/path/allpaths.c
index 4895cee994..2ee478195f 100644
--- a/src/backend/optimizer/path/allpaths.c
+++ b/src/backend/optimizer/path/allpaths.c
@@ -731,6 +731,10 @@ set_rel_consider_parallel(PlannerInfo *root, RelOptInfo *rel,
 		case RTE_RESULT:
 			/* RESULT RTEs, in themselves, are no problem. */
 			break;
+		case RTE_GROUP:
+			/* Shouldn't happen; we're only considering baserels here. */
+			Assert(false);
+			return;
 	}
 
 	/*
diff --git a/src/backend/optimizer/path/equivclass.c b/src/backend/optimizer/path/equivclass.c
index 21ce1ae2e1..61c450bb99 100644
--- a/src/backend/optimizer/path/equivclass.c
+++ b/src/backend/optimizer/path/equivclass.c
@@ -737,6 +737,10 @@ get_eclass_for_sort_expr(PlannerInfo *root,
 		{
 			RelOptInfo *rel = root->simple_rel_array[i];
 
+			/* ignore GROUP RTE */
+			if (i == root->group_rtindex)
+				continue;
+
 			if (rel == NULL)	/* must be an outer join */
 			{
 				Assert(bms_is_member(i, root->outer_join_rels));
@@ -1098,6 +1102,10 @@ generate_base_implied_equalities(PlannerInfo *root)
 		{
 			RelOptInfo *rel = root->simple_rel_array[i];
 
+			/* ignore GROUP RTE */
+			if (i == root->group_rtindex)
+				continue;
+
 			if (rel == NULL)	/* must be an outer join */
 			{
 				Assert(bms_is_member(i, root->outer_join_rels));
@@ -3353,6 +3361,10 @@ get_eclass_indexes_for_relids(PlannerInfo *root, Relids relids)
 	{
 		RelOptInfo *rel = root->simple_rel_array[i];
 
+		/* ignore GROUP RTE */
+		if (i == root->group_rtindex)
+			continue;
+
 		if (rel == NULL)		/* must be an outer join */
 		{
 			Assert(bms_is_member(i, root->outer_join_rels));
diff --git a/src/backend/optimizer/plan/initsplan.c b/src/backend/optimizer/plan/initsplan.c
index e2c68fe6f9..48fad35051 100644
--- a/src/backend/optimizer/plan/initsplan.c
+++ b/src/backend/optimizer/plan/initsplan.c
@@ -1328,6 +1328,10 @@ mark_rels_nulled_by_join(PlannerInfo *root, Index ojrelid,
 	{
 		RelOptInfo *rel = root->simple_rel_array[relid];
 
+		/* ignore GROUP RTE */
+		if (relid == root->group_rtindex)
+			continue;
+
 		if (rel == NULL)		/* must be an outer join */
 		{
 			Assert(bms_is_member(relid, root->outer_join_rels));
diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c
index 032818423f..4a4a4d4114 100644
--- a/src/backend/optimizer/plan/planner.c
+++ b/src/backend/optimizer/plan/planner.c
@@ -748,6 +748,7 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root,
 	 */
 	root->hasJoinRTEs = false;
 	root->hasLateralRTEs = false;
+	root->group_rtindex = 0;
 	hasOuterJoins = false;
 	hasResultRTEs = false;
 	foreach(l, parse->rtable)
@@ -781,6 +782,10 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root,
 			case RTE_RESULT:
 				hasResultRTEs = true;
 				break;
+			case RTE_GROUP:
+				Assert(parse->hasGroupRTE);
+				root->group_rtindex = list_cell_number(parse->rtable, l) + 1;
+				break;
 			default:
 				/* No work here for other RTE types */
 				break;
@@ -836,10 +841,6 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root,
 		preprocess_expression(root, (Node *) parse->targetList,
 							  EXPRKIND_TARGET);
 
-	/* Constant-folding might have removed all set-returning functions */
-	if (parse->hasTargetSRFs)
-		parse->hasTargetSRFs = expression_returns_set((Node *) parse->targetList);
-
 	newWithCheckOptions = NIL;
 	foreach(l, parse->withCheckOptions)
 	{
@@ -969,6 +970,13 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root,
 			rte->values_lists = (List *)
 				preprocess_expression(root, (Node *) rte->values_lists, kind);
 		}
+		else if (rte->rtekind == RTE_GROUP)
+		{
+			/* Preprocess the groupexprs list fully */
+			rte->groupexprs = (List *)
+				preprocess_expression(root, (Node *) rte->groupexprs,
+									  EXPRKIND_TARGET);
+		}
 
 		/*
 		 * Process each element of the securityQuals list as if it were a
@@ -984,6 +992,22 @@ subquery_planner(PlannerGlobal *glob, Query *parse, PlannerInfo *parent_root,
 		}
 	}
 
+	/*
+	 * Replace any Vars in the subquery's targetlist and havingQual that
+	 * reference GROUP outputs with the underlying grouping expressions.
+	 */
+	if (parse->hasGroupRTE)
+	{
+		parse->targetList = (List *)
+			flatten_group_exprs(root, root->parse, (Node *) parse->targetList);
+		parse->havingQual =
+			flatten_group_exprs(root, root->parse, parse->havingQual);
+	}
+
+	/* Constant-folding might have removed all set-returning functions */
+	if (parse->hasTargetSRFs)
+		parse->hasTargetSRFs = expression_returns_set((Node *) parse->targetList);
+
 	/*
 	 * Now that we are done preprocessing expressions, and in particular done
 	 * flattening join alias variables, get rid of the joinaliasvars lists.
diff --git a/src/backend/optimizer/plan/setrefs.c b/src/backend/optimizer/plan/setrefs.c
index 37abcb4701..631d4d2c70 100644
--- a/src/backend/optimizer/plan/setrefs.c
+++ b/src/backend/optimizer/plan/setrefs.c
@@ -557,6 +557,7 @@ add_rte_to_flat_rtable(PlannerGlobal *glob, List *rteperminfos,
 	newrte->coltypes = NIL;
 	newrte->coltypmods = NIL;
 	newrte->colcollations = NIL;
+	newrte->groupexprs = NIL;
 	newrte->securityQuals = NIL;
 
 	glob->finalrtable = lappend(glob->finalrtable, newrte);
diff --git a/src/backend/optimizer/prep/prepjointree.c b/src/backend/optimizer/prep/prepjointree.c
index 5482ab85a7..728c07f464 100644
--- a/src/backend/optimizer/prep/prepjointree.c
+++ b/src/backend/optimizer/prep/prepjointree.c
@@ -1235,6 +1235,7 @@ pull_up_simple_subquery(PlannerInfo *root, Node *jtnode, RangeTblEntry *rte,
 				case RTE_CTE:
 				case RTE_NAMEDTUPLESTORE:
 				case RTE_RESULT:
+				case RTE_GROUP:
 					/* these can't contain any lateral references */
 					break;
 			}
@@ -2218,7 +2219,8 @@ perform_pullup_replace_vars(PlannerInfo *root,
 	}
 
 	/*
-	 * Replace references in the joinaliasvars lists of join RTEs.
+	 * Replace references in the joinaliasvars lists of join RTEs and the
+	 * groupexprs list of group RTE.
 	 */
 	foreach(lc, parse->rtable)
 	{
@@ -2228,6 +2230,10 @@ perform_pullup_replace_vars(PlannerInfo *root,
 			otherrte->joinaliasvars = (List *)
 				pullup_replace_vars((Node *) otherrte->joinaliasvars,
 									rvcontext);
+		else if (otherrte->rtekind == RTE_GROUP)
+			otherrte->groupexprs = (List *)
+				pullup_replace_vars((Node *) otherrte->groupexprs,
+									rvcontext);
 	}
 }
 
@@ -2293,6 +2299,7 @@ replace_vars_in_jointree(Node *jtnode,
 					case RTE_CTE:
 					case RTE_NAMEDTUPLESTORE:
 					case RTE_RESULT:
+					case RTE_GROUP:
 						/* these shouldn't be marked LATERAL */
 						Assert(false);
 						break;
diff --git a/src/backend/optimizer/util/var.c b/src/backend/optimizer/util/var.c
index 844fc30978..9e93370e6c 100644
--- a/src/backend/optimizer/util/var.c
+++ b/src/backend/optimizer/util/var.c
@@ -81,6 +81,8 @@ static bool pull_var_clause_walker(Node *node,
 								   pull_var_clause_context *context);
 static Node *flatten_join_alias_vars_mutator(Node *node,
 											 flatten_join_alias_vars_context *context);
+static Node *flatten_group_exprs_mutator(Node *node,
+										 flatten_join_alias_vars_context *context);
 static Node *add_nullingrels_if_needed(PlannerInfo *root, Node *newnode,
 									   Var *oldvar);
 static bool is_standard_join_alias_expression(Node *newnode, Var *oldvar);
@@ -902,6 +904,151 @@ flatten_join_alias_vars_mutator(Node *node,
 								   (void *) context);
 }
 
+/*
+ * flatten_group_exprs
+ *	  Replace Vars that reference GROUP outputs with the underlying grouping
+ *	  expressions.
+ *
+ * TODO we need to preserve any varnullingrels info attached to the group Vars
+ * we're replacing.
+ */
+Node *
+flatten_group_exprs(PlannerInfo *root, Query *query, Node *node)
+{
+	flatten_join_alias_vars_context context;
+
+	/*
+	 * We do not expect this to be applied to the whole Query, only to
+	 * expressions or LATERAL subqueries.  Hence, if the top node is a Query,
+	 * it's okay to immediately increment sublevels_up.
+	 */
+	Assert(node != (Node *) query);
+
+	context.root = root;
+	context.query = query;
+	context.sublevels_up = 0;
+	/* flag whether grouping expressions could possibly contain SubLinks */
+	context.possible_sublink = query->hasSubLinks;
+	/* if hasSubLinks is already true, no need to work hard */
+	context.inserted_sublink = query->hasSubLinks;
+
+	return flatten_group_exprs_mutator(node, &context);
+}
+
+static Node *
+flatten_group_exprs_mutator(Node *node,
+							flatten_join_alias_vars_context *context)
+{
+	if (node == NULL)
+		return NULL;
+	if (IsA(node, Var))
+	{
+		Var		   *var = (Var *) node;
+		RangeTblEntry *rte;
+		Node	   *newvar;
+
+		/* No change unless Var belongs to the GROUP of the target level */
+		if (var->varlevelsup != context->sublevels_up)
+			return node;		/* no need to copy, really */
+		rte = rt_fetch(var->varno, context->query->rtable);
+		if (rte->rtekind != RTE_GROUP)
+			return node;
+
+		/* Expand group exprs reference */
+		Assert(var->varattno > 0);
+		newvar = (Node *) list_nth(rte->groupexprs, var->varattno - 1);
+		Assert(newvar != NULL);
+		newvar = copyObject(newvar);
+
+		/*
+		 * If we are expanding an expr carried down from an upper query, must
+		 * adjust its varlevelsup fields.
+		 */
+		if (context->sublevels_up != 0)
+			IncrementVarSublevelsUp(newvar, context->sublevels_up, 0);
+
+		/* Preserve original Var's location, if possible */
+		if (IsA(newvar, Var))
+			((Var *) newvar)->location = var->location;
+
+		/* Detect if we are adding a sublink to query */
+		if (context->possible_sublink && !context->inserted_sublink)
+			context->inserted_sublink = checkExprHasSubLink(newvar);
+
+		/*
+		 * TODO var->varnullingrels might have the nullingrel bit that
+		 * references RTE_GROUP.  We're supposed to add it to the replacement
+		 * expression.
+		 *
+		 * Maybe we can do something like add_nullingrels_if_needed().
+		 */
+		return newvar;
+	}
+
+	if (IsA(node, Aggref))
+	{
+		Aggref	   *agg = (Aggref *) node;
+
+		if ((int) agg->agglevelsup == context->sublevels_up)
+		{
+			/*
+			 * If we find an aggregate call of the original level, do not
+			 * recurse into its normal arguments, ORDER BY arguments, or
+			 * filter; there are no grouped vars there.  But we should check
+			 * direct arguments as though they weren't in an aggregate.
+			 */
+			agg = copyObject(agg);
+			agg->aggdirectargs = (List *)
+				flatten_group_exprs_mutator((Node *) agg->aggdirectargs, context);
+
+			return (Node *) agg;
+		}
+
+		/*
+		 * We can skip recursing into aggregates of higher levels altogether,
+		 * since they could not possibly contain Vars of concern to us (see
+		 * transformAggregateCall).  We do need to look at aggregates of lower
+		 * levels, however.
+		 */
+		if ((int) agg->agglevelsup > context->sublevels_up)
+			return node;
+	}
+
+	if (IsA(node, GroupingFunc))
+	{
+		GroupingFunc *grp = (GroupingFunc *) node;
+
+		/*
+		 * If we find a GroupingFunc node of the original or higher level, do
+		 * not recurse into its arguments; there are no grouped vars there.
+		 */
+		if ((int) grp->agglevelsup >= context->sublevels_up)
+			return node;
+	}
+
+	if (IsA(node, Query))
+	{
+		/* Recurse into RTE subquery or not-yet-planned sublink subquery */
+		Query	   *newnode;
+		bool		save_inserted_sublink;
+
+		context->sublevels_up++;
+		save_inserted_sublink = context->inserted_sublink;
+		context->inserted_sublink = ((Query *) node)->hasSubLinks;
+		newnode = query_tree_mutator((Query *) node,
+									 flatten_group_exprs_mutator,
+									 (void *) context,
+									 QTW_IGNORE_GROUPEXPRS);
+		newnode->hasSubLinks |= context->inserted_sublink;
+		context->inserted_sublink = save_inserted_sublink;
+		context->sublevels_up--;
+		return (Node *) newnode;
+	}
+
+	return expression_tree_mutator(node, flatten_group_exprs_mutator,
+								   (void *) context);
+}
+
 /*
  * Add oldvar's varnullingrels, if any, to a flattened join alias expression.
  * The newnode has been copied, so we can modify it freely.
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c
index bee7d8346a..68858e6d7b 100644
--- a/src/backend/parser/parse_agg.c
+++ b/src/backend/parser/parse_agg.c
@@ -26,6 +26,7 @@
 #include "parser/parse_clause.h"
 #include "parser/parse_coerce.h"
 #include "parser/parse_expr.h"
+#include "parser/parse_relation.h"
 #include "parser/parsetree.h"
 #include "rewrite/rewriteManip.h"
 #include "utils/builtins.h"
@@ -47,11 +48,12 @@ typedef struct
 	bool		hasJoinRTEs;
 	List	   *groupClauses;
 	List	   *groupClauseCommonVars;
+	List	   *gset_common;
 	bool		have_non_var_grouping;
 	List	  **func_grouped_rels;
 	int			sublevels_up;
 	bool		in_agg_direct_args;
-} check_ungrouped_columns_context;
+} substitute_grouped_columns_context;
 
 static int	check_agg_arguments(ParseState *pstate,
 								List *directargs,
@@ -59,17 +61,20 @@ static int	check_agg_arguments(ParseState *pstate,
 								Expr *filter);
 static bool check_agg_arguments_walker(Node *node,
 									   check_agg_arguments_context *context);
-static void check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry,
-									List *groupClauses, List *groupClauseCommonVars,
-									bool have_non_var_grouping,
-									List **func_grouped_rels);
-static bool check_ungrouped_columns_walker(Node *node,
-										   check_ungrouped_columns_context *context);
+static Node *substitute_grouped_columns(Node *node, ParseState *pstate, Query *qry,
+										List *groupClauses, List *groupClauseCommonVars,
+										List *gset_common,
+										bool have_non_var_grouping,
+										List **func_grouped_rels);
+static Node *substitute_grouped_columns_mutator(Node *node,
+												substitute_grouped_columns_context *context);
 static void finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
 									List *groupClauses, bool hasJoinRTEs,
 									bool have_non_var_grouping);
 static bool finalize_grouping_exprs_walker(Node *node,
-										   check_ungrouped_columns_context *context);
+										   substitute_grouped_columns_context *context);
+static Var *buildGroupedVar(Node *node, int attnum, Index ressortgroupref,
+							substitute_grouped_columns_context *context);
 static void check_agglevels_and_constraints(ParseState *pstate, Node *expr);
 static List *expand_groupingset_node(GroupingSet *gs);
 static Node *make_agg_arg(Oid argtype, Oid argcollation);
@@ -1156,7 +1161,7 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
 
 	/*
 	 * Build a list of the acceptable GROUP BY expressions for use by
-	 * check_ungrouped_columns().
+	 * substitute_grouped_columns().
 	 *
 	 * We get the TLE, not just the expr, because GROUPING wants to know the
 	 * sortgroupref.
@@ -1206,10 +1211,22 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
 		{
 			groupClauseCommonVars = lappend(groupClauseCommonVars, tle->expr);
 		}
+
 	}
 
 	/*
-	 * Check the targetlist and HAVING clause for ungrouped variables.
+	 * Now build an RTE and nsitem for the result of the grouping step.
+	 */
+	pstate->p_grouping_nsitem =
+		addRangeTableEntryForGroup(pstate, groupClauses);
+
+	qry->rtable = pstate->p_rtable;
+	qry->hasGroupRTE = true;
+
+	/*
+	 * Replace grouped variables in the targetlist and HAVING clause with Vars
+	 * that reference the GROUP RTE.  Emit an error message if we find any
+	 * ungrouped variables.
 	 *
 	 * Note: because we check resjunk tlist elements as well as regular ones,
 	 * this will also find ungrouped variables that came from ORDER BY and
@@ -1225,10 +1242,12 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
 							have_non_var_grouping);
 	if (hasJoinRTEs)
 		clause = flatten_join_alias_vars(NULL, qry, clause);
-	check_ungrouped_columns(clause, pstate, qry,
-							groupClauses, groupClauseCommonVars,
-							have_non_var_grouping,
-							&func_grouped_rels);
+	qry->targetList = (List *)
+		substitute_grouped_columns(clause, pstate, qry,
+								   groupClauses, groupClauseCommonVars,
+								   gset_common,
+								   have_non_var_grouping,
+								   &func_grouped_rels);
 
 	clause = (Node *) qry->havingQual;
 	finalize_grouping_exprs(clause, pstate, qry,
@@ -1236,10 +1255,12 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
 							have_non_var_grouping);
 	if (hasJoinRTEs)
 		clause = flatten_join_alias_vars(NULL, qry, clause);
-	check_ungrouped_columns(clause, pstate, qry,
-							groupClauses, groupClauseCommonVars,
-							have_non_var_grouping,
-							&func_grouped_rels);
+	qry->havingQual =
+		substitute_grouped_columns(clause, pstate, qry,
+								   groupClauses, groupClauseCommonVars,
+								   gset_common,
+								   have_non_var_grouping,
+								   &func_grouped_rels);
 
 	/*
 	 * Per spec, aggregates can't appear in a recursive term.
@@ -1253,14 +1274,16 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
 }
 
 /*
- * check_ungrouped_columns -
- *	  Scan the given expression tree for ungrouped variables (variables
- *	  that are not listed in the groupClauses list and are not within
- *	  the arguments of aggregate functions).  Emit a suitable error message
- *	  if any are found.
+ * substitute_grouped_columns -
+ *	  Scan the given expression tree for grouped variables (variables that
+ *	  are listed in the groupClauses list) and replace them with Vars that
+ *	  reference the GROUP RTE.  Emit a suitable error message if any
+ *	  ungrouped variables (variables that are not listed in the groupClauses
+ *	  list and are not within the arguments of aggregate functions) are
+ *	  found.
  *
  * NOTE: we assume that the given clause has been transformed suitably for
- * parser output.  This means we can use expression_tree_walker.
+ * parser output.  This means we can use expression_tree_mutator.
  *
  * NOTE: we recognize grouping expressions in the main query, but only
  * grouping Vars in subqueries.  For example, this will be rejected,
@@ -1273,37 +1296,39 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
  * This appears to require a whole custom version of equal(), which is
  * way more pain than the feature seems worth.
  */
-static void
-check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry,
-						List *groupClauses, List *groupClauseCommonVars,
-						bool have_non_var_grouping,
-						List **func_grouped_rels)
+static Node *
+substitute_grouped_columns(Node *node, ParseState *pstate, Query *qry,
+						   List *groupClauses, List *groupClauseCommonVars,
+						   List *gset_common,
+						   bool have_non_var_grouping,
+						   List **func_grouped_rels)
 {
-	check_ungrouped_columns_context context;
+	substitute_grouped_columns_context context;
 
 	context.pstate = pstate;
 	context.qry = qry;
 	context.hasJoinRTEs = false;	/* assume caller flattened join Vars */
 	context.groupClauses = groupClauses;
 	context.groupClauseCommonVars = groupClauseCommonVars;
+	context.gset_common = gset_common;
 	context.have_non_var_grouping = have_non_var_grouping;
 	context.func_grouped_rels = func_grouped_rels;
 	context.sublevels_up = 0;
 	context.in_agg_direct_args = false;
-	check_ungrouped_columns_walker(node, &context);
+	return substitute_grouped_columns_mutator(node, &context);
 }
 
-static bool
-check_ungrouped_columns_walker(Node *node,
-							   check_ungrouped_columns_context *context)
+static Node *
+substitute_grouped_columns_mutator(Node *node,
+								   substitute_grouped_columns_context *context)
 {
 	ListCell   *gl;
 
 	if (node == NULL)
-		return false;
+		return NULL;
 	if (IsA(node, Const) ||
 		IsA(node, Param))
-		return false;			/* constants are always acceptable */
+		return node;			/* constants are always acceptable */
 
 	if (IsA(node, Aggref))
 	{
@@ -1314,19 +1339,21 @@ check_ungrouped_columns_walker(Node *node,
 			/*
 			 * If we find an aggregate call of the original level, do not
 			 * recurse into its normal arguments, ORDER BY arguments, or
-			 * filter; ungrouped vars there are not an error.  But we should
-			 * check direct arguments as though they weren't in an aggregate.
-			 * We set a special flag in the context to help produce a useful
+			 * filter; grouped vars there do not need to be replaced and
+			 * ungrouped vars there are not an error.  But we should check
+			 * direct arguments as though they weren't in an aggregate.  We
+			 * set a special flag in the context to help produce a useful
 			 * error message for ungrouped vars in direct arguments.
 			 */
-			bool		result;
+			agg = copyObject(agg);
 
 			Assert(!context->in_agg_direct_args);
 			context->in_agg_direct_args = true;
-			result = check_ungrouped_columns_walker((Node *) agg->aggdirectargs,
-													context);
+			agg->aggdirectargs = (List *)
+				substitute_grouped_columns_mutator((Node *) agg->aggdirectargs,
+												   context);
 			context->in_agg_direct_args = false;
-			return result;
+			return (Node *) agg;
 		}
 
 		/*
@@ -1336,7 +1363,7 @@ check_ungrouped_columns_walker(Node *node,
 		 * levels, however.
 		 */
 		if ((int) agg->agglevelsup > context->sublevels_up)
-			return false;
+			return node;
 	}
 
 	if (IsA(node, GroupingFunc))
@@ -1346,7 +1373,7 @@ check_ungrouped_columns_walker(Node *node,
 		/* handled GroupingFunc separately, no need to recheck at this level */
 
 		if ((int) grp->agglevelsup >= context->sublevels_up)
-			return false;
+			return node;
 	}
 
 	/*
@@ -1358,12 +1385,19 @@ check_ungrouped_columns_walker(Node *node,
 	 */
 	if (context->have_non_var_grouping && context->sublevels_up == 0)
 	{
+		int attnum = 0;
 		foreach(gl, context->groupClauses)
 		{
-			TargetEntry *tle = lfirst(gl);
+			TargetEntry *tle = (TargetEntry *) lfirst(gl);
 
+			attnum++;
 			if (equal(node, tle->expr))
-				return false;	/* acceptable, do not descend more */
+			{
+				/* acceptable, replace it with a GROUP Var */
+				return (Node *) buildGroupedVar(node, attnum,
+												tle->ressortgroupref,
+												context);
+			}
 		}
 	}
 
@@ -1380,22 +1414,30 @@ check_ungrouped_columns_walker(Node *node,
 		char	   *attname;
 
 		if (var->varlevelsup != context->sublevels_up)
-			return false;		/* it's not local to my query, ignore */
+			return node;		/* it's not local to my query, ignore */
 
 		/*
 		 * Check for a match, if we didn't do it above.
 		 */
 		if (!context->have_non_var_grouping || context->sublevels_up != 0)
 		{
+			int attnum = 0;
 			foreach(gl, context->groupClauses)
 			{
-				Var		   *gvar = (Var *) ((TargetEntry *) lfirst(gl))->expr;
+				TargetEntry *tle = (TargetEntry *) lfirst(gl);
+				Var		   *gvar = (Var *) tle->expr;
 
+				attnum++;
 				if (IsA(gvar, Var) &&
 					gvar->varno == var->varno &&
 					gvar->varattno == var->varattno &&
 					gvar->varlevelsup == 0)
-					return false;	/* acceptable, we're okay */
+				{
+					/* acceptable, replace it with a GROUP Var */
+					return (Node *) buildGroupedVar(node, attnum,
+													tle->ressortgroupref,
+													context);
+				}
 			}
 		}
 
@@ -1416,7 +1458,7 @@ check_ungrouped_columns_walker(Node *node,
 		 * the constraintDeps list.
 		 */
 		if (list_member_int(*context->func_grouped_rels, var->varno))
-			return false;		/* previously proven acceptable */
+			return node;		/* previously proven acceptable */
 
 		Assert(var->varno > 0 &&
 			   (int) var->varno <= list_length(context->pstate->p_rtable));
@@ -1431,7 +1473,7 @@ check_ungrouped_columns_walker(Node *node,
 			{
 				*context->func_grouped_rels =
 					lappend_int(*context->func_grouped_rels, var->varno);
-				return false;	/* acceptable */
+				return node;	/* acceptable */
 			}
 		}
 
@@ -1456,18 +1498,18 @@ check_ungrouped_columns_walker(Node *node,
 	if (IsA(node, Query))
 	{
 		/* Recurse into subselects */
-		bool		result;
+		Query	   *newnode;
 
 		context->sublevels_up++;
-		result = query_tree_walker((Query *) node,
-								   check_ungrouped_columns_walker,
-								   (void *) context,
-								   0);
+		newnode = query_tree_mutator((Query *) node,
+									 substitute_grouped_columns_mutator,
+									 (void *) context,
+									 0);
 		context->sublevels_up--;
-		return result;
+		return (Node *) newnode;
 	}
-	return expression_tree_walker(node, check_ungrouped_columns_walker,
-								  (void *) context);
+	return expression_tree_mutator(node, substitute_grouped_columns_mutator,
+								   (void *) context);
 }
 
 /*
@@ -1475,9 +1517,9 @@ check_ungrouped_columns_walker(Node *node,
  *	  Scan the given expression tree for GROUPING() and related calls,
  *	  and validate and process their arguments.
  *
- * This is split out from check_ungrouped_columns above because it needs
+ * This is split out from substitute_grouped_columns above because it needs
  * to modify the nodes (which it does in-place, not via a mutator) while
- * check_ungrouped_columns may see only a copy of the original thanks to
+ * substitute_grouped_columns may see only a copy of the original thanks to
  * flattening of join alias vars. So here, we flatten each individual
  * GROUPING argument as we see it before comparing it.
  */
@@ -1486,7 +1528,7 @@ finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
 						List *groupClauses, bool hasJoinRTEs,
 						bool have_non_var_grouping)
 {
-	check_ungrouped_columns_context context;
+	substitute_grouped_columns_context context;
 
 	context.pstate = pstate;
 	context.qry = qry;
@@ -1502,7 +1544,7 @@ finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
 
 static bool
 finalize_grouping_exprs_walker(Node *node,
-							   check_ungrouped_columns_context *context)
+							   substitute_grouped_columns_context *context)
 {
 	ListCell   *gl;
 
@@ -1643,6 +1685,37 @@ finalize_grouping_exprs_walker(Node *node,
 								  (void *) context);
 }
 
+/*
+ * buildGroupedVar -
+ *	  build a Var node that references the GROUP RTE
+ */
+static Var *
+buildGroupedVar(Node *node, int attnum, Index ressortgroupref,
+				substitute_grouped_columns_context *context)
+{
+	Var		   *var;
+	ParseNamespaceItem *grouping_nsitem = context->pstate->p_grouping_nsitem;
+	ParseNamespaceColumn *nscol = grouping_nsitem->p_nscolumns + attnum - 1;
+
+	Assert(nscol->p_varno == grouping_nsitem->p_rtindex);
+	var = makeVar(nscol->p_varno,
+				  nscol->p_varattno,
+				  nscol->p_vartype,
+				  nscol->p_vartypmod,
+				  nscol->p_varcollid,
+				  context->sublevels_up);
+	/* makeVar doesn't offer parameters for these, so set by hand: */
+	var->varnosyn = nscol->p_varnosyn;
+	var->varattnosyn = nscol->p_varattnosyn;
+
+	if (context->qry->groupingSets &&
+		!list_member_int(context->gset_common, ressortgroupref))
+		var->varnullingrels =
+			bms_add_member(var->varnullingrels, grouping_nsitem->p_rtindex);
+
+	return var;
+}
+
 
 /*
  * Given a GroupingSet node, expand it and return a list of lists.
diff --git a/src/backend/parser/parse_relation.c b/src/backend/parser/parse_relation.c
index 2f64eaf0e3..6947638425 100644
--- a/src/backend/parser/parse_relation.c
+++ b/src/backend/parser/parse_relation.c
@@ -2557,6 +2557,79 @@ addRangeTableEntryForENR(ParseState *pstate,
 									tupdesc);
 }
 
+/*
+ * Add an entry for grouping step to the pstate's range table (p_rtable).
+ * Then, construct and return a ParseNamespaceItem for the new RTE.
+ */
+ParseNamespaceItem *
+addRangeTableEntryForGroup(ParseState *pstate,
+						   List *groupClauses)
+{
+	RangeTblEntry *rte = makeNode(RangeTblEntry);
+	Alias	   *eref;
+	List	   *groupexprs;
+	List	   *coltypes,
+			   *coltypmods,
+			   *colcollations;
+	ListCell   *lc;
+	ParseNamespaceItem *nsitem;
+
+	Assert(pstate != NULL);
+
+	rte->rtekind = RTE_GROUP;
+	rte->alias = NULL;
+
+	eref = makeAlias("*GROUP*", NIL);
+
+	/* fill in any unspecified alias columns, and extract column type info */
+	groupexprs = NIL;
+	coltypes = coltypmods = colcollations = NIL;
+	foreach(lc, groupClauses)
+	{
+		TargetEntry *te = (TargetEntry *) lfirst(lc);
+		char	   *colname = te->resname ? pstrdup(te->resname) : "unamed_col";
+
+		eref->colnames = lappend(eref->colnames, makeString(colname));
+
+		groupexprs = lappend(groupexprs, copyObject(te->expr));
+
+		coltypes = lappend_oid(coltypes,
+							   exprType((Node *) te->expr));
+		coltypmods = lappend_int(coltypmods,
+								 exprTypmod((Node *) te->expr));
+		colcollations = lappend_oid(colcollations,
+									exprCollation((Node *) te->expr));
+	}
+
+	rte->eref = eref;
+	rte->groupexprs = groupexprs;
+
+	/*
+	 * Set flags.
+	 *
+	 * The grouping step is never checked for access rights, so no need to
+	 * perform addRTEPermissionInfo().
+	 */
+	rte->lateral = false;
+	rte->inFromCl = false;
+
+	/*
+	 * Add completed RTE to pstate's range table list, so that we know its
+	 * index.  But we don't add it to the join list --- caller must do that if
+	 * appropriate.
+	 */
+	pstate->p_rtable = lappend(pstate->p_rtable, rte);
+
+	/*
+	 * Build a ParseNamespaceItem, but don't add it to the pstate's namespace
+	 * list --- caller must do that if appropriate.
+	 */
+	nsitem = buildNSItemFromLists(rte, list_length(pstate->p_rtable),
+								  coltypes, coltypmods, colcollations);
+
+	return nsitem;
+}
+
 
 /*
  * Has the specified refname been selected FOR UPDATE/FOR SHARE?
@@ -3003,6 +3076,7 @@ expandRTE(RangeTblEntry *rte, int rtindex, int sublevels_up,
 			}
 			break;
 		case RTE_RESULT:
+		case RTE_GROUP:
 			/* These expose no columns, so nothing to do */
 			break;
 		default:
@@ -3317,10 +3391,11 @@ get_rte_attribute_is_dropped(RangeTblEntry *rte, AttrNumber attnum)
 		case RTE_TABLEFUNC:
 		case RTE_VALUES:
 		case RTE_CTE:
+		case RTE_GROUP:
 
 			/*
-			 * Subselect, Table Functions, Values, CTE RTEs never have dropped
-			 * columns
+			 * Subselect, Table Functions, Values, CTE, GROUP RTEs never have
+			 * dropped columns
 			 */
 			result = false;
 			break;
diff --git a/src/backend/parser/parse_target.c b/src/backend/parser/parse_target.c
index ee6fcd0503..1f8edc05c9 100644
--- a/src/backend/parser/parse_target.c
+++ b/src/backend/parser/parse_target.c
@@ -380,6 +380,7 @@ markTargetListOrigin(ParseState *pstate, TargetEntry *tle,
 		case RTE_TABLEFUNC:
 		case RTE_NAMEDTUPLESTORE:
 		case RTE_RESULT:
+		case RTE_GROUP:
 			/* not a simple relation, leave it unmarked */
 			break;
 		case RTE_CTE:
@@ -1579,6 +1580,7 @@ expandRecordVariable(ParseState *pstate, Var *var, int levelsup)
 		case RTE_VALUES:
 		case RTE_NAMEDTUPLESTORE:
 		case RTE_RESULT:
+		case RTE_GROUP:
 
 			/*
 			 * This case should not occur: a column of a table, values list,
diff --git a/src/backend/utils/adt/ruleutils.c b/src/backend/utils/adt/ruleutils.c
index 9618619762..9b571b54cb 100644
--- a/src/backend/utils/adt/ruleutils.c
+++ b/src/backend/utils/adt/ruleutils.c
@@ -5433,11 +5433,28 @@ get_query_def(Query *query, StringInfo buf, List *parentnamespace,
 {
 	deparse_context context;
 	deparse_namespace dpns;
+	int			rtable_size;
 
 	/* Guard against excessively long or deeply-nested queries */
 	CHECK_FOR_INTERRUPTS();
 	check_stack_depth();
 
+	rtable_size = query->hasGroupRTE ?
+				  list_length(query->rtable) - 1 :
+				  list_length(query->rtable);
+
+	/*
+	 * Replace any Vars in the query's targetlist and havingQual that reference
+	 * GROUP outputs with the underlying grouping expressions.
+	 */
+	if (query->hasGroupRTE)
+	{
+		query->targetList = (List *)
+			flatten_group_exprs(NULL, query, (Node *) query->targetList);
+		query->havingQual =
+			flatten_group_exprs(NULL, query, query->havingQual);
+	}
+
 	/*
 	 * Before we begin to examine the query, acquire locks on referenced
 	 * relations, and fix up deleted columns in JOIN RTEs.  This ensures
@@ -5454,7 +5471,7 @@ get_query_def(Query *query, StringInfo buf, List *parentnamespace,
 	context.windowClause = NIL;
 	context.windowTList = NIL;
 	context.varprefix = (parentnamespace != NIL ||
-						 list_length(query->rtable) != 1);
+						 rtable_size != 1);
 	context.prettyFlags = prettyFlags;
 	context.wrapColumn = wrapColumn;
 	context.indentLevel = startIndent;
@@ -7838,6 +7855,7 @@ get_name_for_var_field(Var *var, int fieldno,
 		case RTE_VALUES:
 		case RTE_NAMEDTUPLESTORE:
 		case RTE_RESULT:
+		case RTE_GROUP:
 
 			/*
 			 * This case should not occur: a column of a table, values list,
diff --git a/src/include/commands/explain.h b/src/include/commands/explain.h
index 9b8b351d9a..35be084869 100644
--- a/src/include/commands/explain.h
+++ b/src/include/commands/explain.h
@@ -67,6 +67,7 @@ typedef struct ExplainState
 	List	   *deparse_cxt;	/* context list for deparsing expressions */
 	Bitmapset  *printed_subplans;	/* ids of SubPlans we've printed */
 	bool		hide_workers;	/* set if we find an invisible Gather */
+	int			rtable_size;	/* length of rtable excluding GROUP entries */
 	/* state related to the current plan node */
 	ExplainWorkersState *workers_state; /* needed if parallel plan */
 } ExplainState;
diff --git a/src/include/nodes/nodeFuncs.h b/src/include/nodes/nodeFuncs.h
index eaba59bed8..1f0de5b3d8 100644
--- a/src/include/nodes/nodeFuncs.h
+++ b/src/include/nodes/nodeFuncs.h
@@ -31,6 +31,8 @@ struct PlanState;				/* avoid including execnodes.h too */
 #define QTW_DONT_COPY_QUERY			0x40	/* do not copy top Query */
 #define QTW_EXAMINE_SORTGROUP		0x80	/* include SortGroupClause lists */
 
+#define QTW_IGNORE_GROUPEXPRS		0x100	/* GROUP expressions lists */
+
 /* callback function for check_functions_in_node */
 typedef bool (*check_function_callback) (Oid func_id, void *context);
 
diff --git a/src/include/nodes/parsenodes.h b/src/include/nodes/parsenodes.h
index ddfed02db2..eb4054bbe3 100644
--- a/src/include/nodes/parsenodes.h
+++ b/src/include/nodes/parsenodes.h
@@ -160,6 +160,8 @@ typedef struct Query
 	bool		hasForUpdate pg_node_attr(query_jumble_ignore);
 	/* rewriter has applied some RLS policy */
 	bool		hasRowSecurity pg_node_attr(query_jumble_ignore);
+	/* parser has added a GROUP RTE */
+	bool		hasGroupRTE pg_node_attr(query_jumble_ignore);
 	/* is a RETURN statement */
 	bool		isReturn pg_node_attr(query_jumble_ignore);
 
@@ -1036,6 +1038,7 @@ typedef enum RTEKind
 	RTE_RESULT,					/* RTE represents an empty FROM clause; such
 								 * RTEs are added by the planner, they're not
 								 * present during parsing or rewriting */
+	RTE_GROUP,					/* the grouping step */
 } RTEKind;
 
 typedef struct RangeTblEntry
@@ -1242,6 +1245,12 @@ typedef struct RangeTblEntry
 	/* estimated or actual from caller */
 	Cardinality enrtuples pg_node_attr(query_jumble_ignore);
 
+	/*
+	 * Fields valid for GROUP RTEs (else NULL/zero):
+	 */
+	/* list of expressions grouped on */
+	List	   *groupexprs pg_node_attr(query_jumble_ignore);
+
 	/*
 	 * Fields valid in all RTEs:
 	 */
diff --git a/src/include/nodes/pathnodes.h b/src/include/nodes/pathnodes.h
index 14ef296ab7..c082693e7c 100644
--- a/src/include/nodes/pathnodes.h
+++ b/src/include/nodes/pathnodes.h
@@ -505,6 +505,11 @@ struct PlannerInfo
 	/* true if planning a recursive WITH item */
 	bool		hasRecursion;
 
+	/*
+	 * The rangetable index for the GROUP RTE, or 0 if there is no GROUP RTE.
+	 */
+	int			group_rtindex;
+
 	/*
 	 * Information about aggregates. Filled by preprocess_aggrefs().
 	 */
diff --git a/src/include/optimizer/optimizer.h b/src/include/optimizer/optimizer.h
index 7b63c5cf71..93e3dc719d 100644
--- a/src/include/optimizer/optimizer.h
+++ b/src/include/optimizer/optimizer.h
@@ -201,5 +201,6 @@ extern bool contain_vars_of_level(Node *node, int levelsup);
 extern int	locate_var_of_level(Node *node, int levelsup);
 extern List *pull_var_clause(Node *node, int flags);
 extern Node *flatten_join_alias_vars(PlannerInfo *root, Query *query, Node *node);
+extern Node *flatten_group_exprs(PlannerInfo *root, Query *query, Node *node);
 
 #endif							/* OPTIMIZER_H */
diff --git a/src/include/parser/parse_node.h b/src/include/parser/parse_node.h
index 5b781d87a9..ef78fd8224 100644
--- a/src/include/parser/parse_node.h
+++ b/src/include/parser/parse_node.h
@@ -237,6 +237,8 @@ struct ParseState
 	ParseParamRefHook p_paramref_hook;
 	CoerceParamHook p_coerce_param_hook;
 	void	   *p_ref_hook_state;	/* common passthrough link for above */
+
+	ParseNamespaceItem *p_grouping_nsitem;	/* NSItem for grouping, or NULL */
 };
 
 /*
diff --git a/src/include/parser/parse_relation.h b/src/include/parser/parse_relation.h
index bea2da5496..91fd8e243b 100644
--- a/src/include/parser/parse_relation.h
+++ b/src/include/parser/parse_relation.h
@@ -100,6 +100,8 @@ extern ParseNamespaceItem *addRangeTableEntryForCTE(ParseState *pstate,
 extern ParseNamespaceItem *addRangeTableEntryForENR(ParseState *pstate,
 													RangeVar *rv,
 													bool inFromCl);
+extern ParseNamespaceItem *addRangeTableEntryForGroup(ParseState *pstate,
+													  List *groupClauses);
 extern RTEPermissionInfo *addRTEPermissionInfo(List **rteperminfos,
 											   RangeTblEntry *rte);
 extern RTEPermissionInfo *getRTEPermissionInfo(List *rteperminfos,
diff --git a/src/test/regress/expected/groupingsets.out b/src/test/regress/expected/groupingsets.out
index e1f0660810..9c7590e7ba 100644
--- a/src/test/regress/expected/groupingsets.out
+++ b/src/test/regress/expected/groupingsets.out
@@ -2150,4 +2150,53 @@ select (select grouping(v1)) from (values ((select 1))) v(v1) group by v1;
         0
 (1 row)
 
+-- test handling of subqueries in grouping sets
+create temp table gstest5(id integer primary key, v integer);
+insert into gstest5 select i, i from generate_series(1,5)i;
+explain (costs off)
+select grouping((select t1.v from gstest5 t2 where id = t1.id)),
+       (select t1.v from gstest5 t2 where id = t1.id) as s
+from gstest5 t1
+group by grouping sets(v, s)
+order by case when grouping((select t1.v from gstest5 t2 where id = t1.id)) = 0
+              then (select t1.v from gstest5 t2 where id = t1.id)
+              else null end
+         nulls first;
+                                                QUERY PLAN                                                 
+-----------------------------------------------------------------------------------------------------------
+ Sort
+   Sort Key: (CASE WHEN (GROUPING((SubPlan 2)) = 0) THEN ((SubPlan 3)) ELSE NULL::integer END) NULLS FIRST
+   ->  HashAggregate
+         Hash Key: t1.v
+         Hash Key: (SubPlan 3)
+         ->  Seq Scan on gstest5 t1
+               SubPlan 3
+                 ->  Bitmap Heap Scan on gstest5 t2
+                       Recheck Cond: (id = t1.id)
+                       ->  Bitmap Index Scan on gstest5_pkey
+                             Index Cond: (id = t1.id)
+(11 rows)
+
+select grouping((select t1.v from gstest5 t2 where id = t1.id)),
+       (select t1.v from gstest5 t2 where id = t1.id) as s
+from gstest5 t1
+group by grouping sets(v, s)
+order by case when grouping((select t1.v from gstest5 t2 where id = t1.id)) = 0
+              then (select t1.v from gstest5 t2 where id = t1.id)
+              else null end
+         nulls first;
+ grouping | s 
+----------+---
+        1 |  
+        1 |  
+        1 |  
+        1 |  
+        1 |  
+        0 | 1
+        0 | 2
+        0 | 3
+        0 | 4
+        0 | 5
+(10 rows)
+
 -- end
diff --git a/src/test/regress/sql/groupingsets.sql b/src/test/regress/sql/groupingsets.sql
index 90ba27257a..0520e44aeb 100644
--- a/src/test/regress/sql/groupingsets.sql
+++ b/src/test/regress/sql/groupingsets.sql
@@ -589,4 +589,27 @@ explain (costs off)
 select (select grouping(v1)) from (values ((select 1))) v(v1) group by v1;
 select (select grouping(v1)) from (values ((select 1))) v(v1) group by v1;
 
+-- test handling of subqueries in grouping sets
+create temp table gstest5(id integer primary key, v integer);
+insert into gstest5 select i, i from generate_series(1,5)i;
+
+explain (costs off)
+select grouping((select t1.v from gstest5 t2 where id = t1.id)),
+       (select t1.v from gstest5 t2 where id = t1.id) as s
+from gstest5 t1
+group by grouping sets(v, s)
+order by case when grouping((select t1.v from gstest5 t2 where id = t1.id)) = 0
+              then (select t1.v from gstest5 t2 where id = t1.id)
+              else null end
+         nulls first;
+
+select grouping((select t1.v from gstest5 t2 where id = t1.id)),
+       (select t1.v from gstest5 t2 where id = t1.id) as s
+from gstest5 t1
+group by grouping sets(v, s)
+order by case when grouping((select t1.v from gstest5 t2 where id = t1.id)) = 0
+              then (select t1.v from gstest5 t2 where id = t1.id)
+              else null end
+         nulls first;
+
 -- end
-- 
2.34.1

