From cbee5b8164a87196375d236850bd9584cb2fb46e Mon Sep 17 00:00:00 2001
From: David Rowley <dgrowley@gmail.com>
Date: Mon, 3 May 2021 16:31:29 +1200
Subject: [PATCH v2 1/2] Add planner support for ORDER BY aggregates

ORDER BY aggregates have, since they were implemented in Postgres, been
executed by always performing a sort in nodeAgg.c to sort the tuples in the
current group into the correct order before calling the transition function
on the sorted results.  This was not great as often there might be an index
that could have provided pre-sorted input and allowed the transition
functions to be called as the rows come in, rather than having to store them
in a tuplestore in order to sort them later.

Here we get the planner on-board with picking a plan that provides
pre-sorted inputs to ORDER BY aggregates.  Since there can be many ORDER
BY aggregates in any given query level, it's very possible that we can't
find an order that suits all aggregates, so we just start with the sort
order of the first one, then run through the remainder to see if any of
them require a stricter variation of that sort order.  For example:

SELECT agg(a ORDER BY a),agg2(a ORDER BY a,b) ...

would request the sort order to be {a,b} because {a} is a leading subset
(prefix) of the sort order {a,b}, whereas:

SELECT agg(a ORDER BY a),agg2(a ORDER BY c) ...

would just pick a plan ordered by {a}; agg2 would then still perform its own
sort inside nodeAgg.c.

Making DISTINCT aggregates use pre-sorted input requires a bit more work.
Those are still handled the traditional way, by performing a sort inside
nodeAgg.c.
---
 src/backend/executor/execExpr.c               | 27 +++++++--
 src/backend/executor/nodeAgg.c                | 20 ++++++-
 src/backend/nodes/copyfuncs.c                 |  1 +
 src/backend/nodes/equalfuncs.c                |  1 +
 src/backend/nodes/outfuncs.c                  |  1 +
 src/backend/nodes/readfuncs.c                 |  1 +
 src/backend/optimizer/plan/planner.c          | 59 +++++++++++++++++++
 src/include/executor/nodeAgg.h                |  5 ++
 src/include/nodes/primnodes.h                 |  4 ++
 .../regress/expected/partition_aggregate.out  |  8 +--
 10 files changed, 115 insertions(+), 12 deletions(-)

diff --git a/src/backend/executor/execExpr.c b/src/backend/executor/execExpr.c
index 2c8c414a14..a6e9c48f11 100644
--- a/src/backend/executor/execExpr.c
+++ b/src/backend/executor/execExpr.c
@@ -3417,13 +3417,16 @@ ExecBuildAggTrans(AggState *aggstate, AggStatePerPhase phase,
 				scratch.resnull = &state->resnull;
 			}
 			argno++;
+
+			Assert(pertrans->numInputs == argno);
 		}
-		else if (pertrans->numSortCols == 0)
+		else if (!pertrans->aggsortrequired)
 		{
 			ListCell   *arg;
 
 			/*
-			 * Normal transition function without ORDER BY / DISTINCT.
+			 * Normal transition function without ORDER BY / DISTINCT or with
+			 * ORDER BY but the planner has given us pre-sorted input.
 			 */
 			strictargs = trans_fcinfo->args + 1;
 
@@ -3431,6 +3434,13 @@ ExecBuildAggTrans(AggState *aggstate, AggStatePerPhase phase,
 			{
 				TargetEntry *source_tle = (TargetEntry *) lfirst(arg);
 
+				/*
+				 * Don't initialize args for any ORDER BY clause that might
+				 * exist in a presorted aggregate.
+				 */
+				if (argno == pertrans->numTransInputs)
+					break;
+
 				/*
 				 * Start from 1, since the 0th arg will be the transition
 				 * value
@@ -3440,11 +3450,13 @@ ExecBuildAggTrans(AggState *aggstate, AggStatePerPhase phase,
 								&trans_fcinfo->args[argno + 1].isnull);
 				argno++;
 			}
+			Assert(pertrans->numTransInputs == argno);
 		}
 		else if (pertrans->numInputs == 1)
 		{
 			/*
-			 * DISTINCT and/or ORDER BY case, with a single column sorted on.
+			 * Non-presorted DISTINCT and/or ORDER BY case, with a single
+			 * column sorted on.
 			 */
 			TargetEntry *source_tle =
 			(TargetEntry *) linitial(pertrans->aggref->args);
@@ -3456,11 +3468,14 @@ ExecBuildAggTrans(AggState *aggstate, AggStatePerPhase phase,
 							&state->resnull);
 			strictnulls = &state->resnull;
 			argno++;
+
+			Assert(pertrans->numInputs == argno);
 		}
 		else
 		{
 			/*
-			 * DISTINCT and/or ORDER BY case, with multiple columns sorted on.
+			 * Non-presorted DISTINCT and/or ORDER BY case, with multiple
+			 * columns sorted on.
 			 */
 			Datum	   *values = pertrans->sortslot->tts_values;
 			bool	   *nulls = pertrans->sortslot->tts_isnull;
@@ -3476,8 +3491,8 @@ ExecBuildAggTrans(AggState *aggstate, AggStatePerPhase phase,
 								&values[argno], &nulls[argno]);
 				argno++;
 			}
+			Assert(pertrans->numInputs == argno);
 		}
-		Assert(pertrans->numInputs == argno);
 
 		/*
 		 * For a strict transfn, nothing happens when there's a NULL input; we
@@ -3638,7 +3653,7 @@ ExecBuildAggTransCall(ExprState *state, AggState *aggstate,
 	 * process_ordered_aggregate_{single, multi} and
 	 * advance_transition_function.
 	 */
-	if (pertrans->numSortCols == 0)
+	if (!pertrans->aggsortrequired)
 	{
 		if (pertrans->transtypeByVal)
 		{
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 914b02ceee..e28d53c17b 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -602,7 +602,7 @@ initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans,
 	/*
 	 * Start a fresh sort operation for each DISTINCT/ORDER BY aggregate.
 	 */
-	if (pertrans->numSortCols > 0)
+	if (pertrans->aggsortrequired)
 	{
 		/*
 		 * In case of rescan, maybe there could be an uncompleted sort
@@ -1328,7 +1328,7 @@ finalize_aggregates(AggState *aggstate,
 
 		pergroupstate = &pergroup[transno];
 
-		if (pertrans->numSortCols > 0)
+		if (pertrans->aggsortrequired)
 		{
 			Assert(aggstate->aggstrategy != AGG_HASHED &&
 				   aggstate->aggstrategy != AGG_MIXED);
@@ -4220,6 +4220,11 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
 	 * stick them into arrays.  We ignore ORDER BY for an ordered-set agg,
 	 * however; the agg's transfn and finalfn are responsible for that.
 	 *
+	 * When the planner has set the aggpresorted flag, the input to the
+	 * aggregate is already correctly sorted.  For ORDER BY aggregates we can
+	 * simply treat these as normal aggregates.  For DISTINCT aggregates we
+	 * must still handle de-duplication of consecutive non-distinct values.
+	 *
 	 * Note that by construction, if there is a DISTINCT clause then the ORDER
 	 * BY clause is a prefix of it (see transformDistinctClause).
 	 */
@@ -4227,18 +4232,29 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
 	{
 		sortlist = NIL;
 		numSortCols = numDistinctCols = 0;
+		pertrans->aggsortrequired = false;
+	}
+	else if (aggref->aggpresorted)
+	{
+		/* DISTINCT not yet supported for aggpresorted */
+		Assert(aggref->aggdistinct == NIL);
+		sortlist = NIL;
+		numSortCols = numDistinctCols = 0;
+		pertrans->aggsortrequired = false;
 	}
 	else if (aggref->aggdistinct)
 	{
 		sortlist = aggref->aggdistinct;
 		numSortCols = numDistinctCols = list_length(sortlist);
 		Assert(numSortCols >= list_length(aggref->aggorder));
+		pertrans->aggsortrequired = true;
 	}
 	else
 	{
 		sortlist = aggref->aggorder;
 		numSortCols = list_length(sortlist);
 		numDistinctCols = 0;
+		pertrans->aggsortrequired = (numSortCols > 0);
 	}
 
 	pertrans->numSortCols = numSortCols;
diff --git a/src/backend/nodes/copyfuncs.c b/src/backend/nodes/copyfuncs.c
index 6fef067957..6b71209476 100644
--- a/src/backend/nodes/copyfuncs.c
+++ b/src/backend/nodes/copyfuncs.c
@@ -1541,6 +1541,7 @@ _copyAggref(const Aggref *from)
 	COPY_SCALAR_FIELD(aggstar);
 	COPY_SCALAR_FIELD(aggvariadic);
 	COPY_SCALAR_FIELD(aggkind);
+	COPY_SCALAR_FIELD(aggpresorted);
 	COPY_SCALAR_FIELD(agglevelsup);
 	COPY_SCALAR_FIELD(aggsplit);
 	COPY_SCALAR_FIELD(aggno);
diff --git a/src/backend/nodes/equalfuncs.c b/src/backend/nodes/equalfuncs.c
index b9cc7b199c..235115900b 100644
--- a/src/backend/nodes/equalfuncs.c
+++ b/src/backend/nodes/equalfuncs.c
@@ -230,6 +230,7 @@ _equalAggref(const Aggref *a, const Aggref *b)
 	COMPARE_SCALAR_FIELD(aggstar);
 	COMPARE_SCALAR_FIELD(aggvariadic);
 	COMPARE_SCALAR_FIELD(aggkind);
+	COMPARE_SCALAR_FIELD(aggpresorted);
 	COMPARE_SCALAR_FIELD(agglevelsup);
 	COMPARE_SCALAR_FIELD(aggsplit);
 	COMPARE_SCALAR_FIELD(aggno);
diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c
index e09e4f77fe..74400f3b62 100644
--- a/src/backend/nodes/outfuncs.c
+++ b/src/backend/nodes/outfuncs.c
@@ -1179,6 +1179,7 @@ _outAggref(StringInfo str, const Aggref *node)
 	WRITE_BOOL_FIELD(aggstar);
 	WRITE_BOOL_FIELD(aggvariadic);
 	WRITE_CHAR_FIELD(aggkind);
+	WRITE_BOOL_FIELD(aggpresorted);
 	WRITE_UINT_FIELD(agglevelsup);
 	WRITE_ENUM_FIELD(aggsplit, AggSplit);
 	WRITE_INT_FIELD(aggno);
diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c
index 3dec0a2508..e46629fb9c 100644
--- a/src/backend/nodes/readfuncs.c
+++ b/src/backend/nodes/readfuncs.c
@@ -655,6 +655,7 @@ _readAggref(void)
 	READ_BOOL_FIELD(aggstar);
 	READ_BOOL_FIELD(aggvariadic);
 	READ_CHAR_FIELD(aggkind);
+	READ_BOOL_FIELD(aggpresorted);
 	READ_UINT_FIELD(agglevelsup);
 	READ_ENUM_FIELD(aggsplit, AggSplit);
 	READ_INT_FIELD(aggno);
diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c
index 1868c4eff4..721a081853 100644
--- a/src/backend/optimizer/plan/planner.c
+++ b/src/backend/optimizer/plan/planner.c
@@ -24,6 +24,7 @@
 #include "access/sysattr.h"
 #include "access/table.h"
 #include "access/xact.h"
+#include "catalog/pg_aggregate.h"
 #include "catalog/pg_constraint.h"
 #include "catalog/pg_inherits.h"
 #include "catalog/pg_proc.h"
@@ -3080,6 +3081,64 @@ standard_qp_callback(PlannerInfo *root, void *extra)
 	else
 		root->group_pathkeys = NIL;
 
+	/* Determine pathkeys for aggregate functions with an ORDER BY */
+	if (parse->groupingSets == NIL && root->numOrderedAggs > 0 &&
+		(qp_extra->groupClause == NIL || root->group_pathkeys))
+	{
+		ListCell   *lc;
+		List	   *pathkeys = NIL;
+		List	   *sortlist;
+
+		foreach(lc, root->agginfos)
+		{
+			AggInfo    *agginfo = (AggInfo *) lfirst(lc);
+			Aggref	   *aggref = agginfo->representative_aggref;
+
+			if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
+				continue;
+
+			/* DISTINCT aggregates not yet supported by the planner */
+			if (aggref->aggdistinct != NIL)
+				continue;
+
+			if (aggref->aggorder != NIL)
+				sortlist = aggref->aggorder;
+			else
+				continue;
+
+			/*
+			 * Find the pathkeys with the most sorted derivative of the first
+			 * Aggref. For example, if we determine the pathkeys for the first
+			 * Aggref to be {a}, and we find another with {a,b}, then we use
+			 * {a,b} since it's useful for more Aggrefs than just {a}.  We
+			 * currently ignore anything that might have a longer list of
+			 * pathkeys than the first Aggref if it is not contained in the
+			 * pathkeys for the first agg.  We can't practically plan for all
+			 * orders of each Aggref, so this seems like the best compromise.
+			 */
+			if (pathkeys == NIL)
+			{
+				pathkeys = make_pathkeys_for_sortclauses(root, sortlist,
+														 aggref->args);
+				aggref->aggpresorted = true;
+			}
+			else
+			{
+				List	   *pathkeys2 = make_pathkeys_for_sortclauses(root,
+																	  sortlist,
+																	  aggref->args);
+
+				if (pathkeys_contained_in(pathkeys, pathkeys2))
+				{
+					pathkeys = pathkeys2;
+					aggref->aggpresorted = true;
+				}
+			}
+		}
+
+		root->group_pathkeys = list_concat(root->group_pathkeys, pathkeys);
+	}
+
 	/* We consider only the first (bottom) window in pathkeys logic */
 	if (activeWindows != NIL)
 	{
diff --git a/src/include/executor/nodeAgg.h b/src/include/executor/nodeAgg.h
index 398446d11f..bcd0643699 100644
--- a/src/include/executor/nodeAgg.h
+++ b/src/include/executor/nodeAgg.h
@@ -48,6 +48,11 @@ typedef struct AggStatePerTransData
 	 */
 	bool		aggshared;
 
+	/*
+	 * True for ORDER BY and DISTINCT aggregates that are not aggpresorted
+	 */
+	bool		aggsortrequired;
+
 	/*
 	 * Number of aggregated input columns.  This includes ORDER BY expressions
 	 * in both the plain-agg and ordered-set cases.  Ordered-set direct args
diff --git a/src/include/nodes/primnodes.h b/src/include/nodes/primnodes.h
index 996c3e4016..2f3cc39d4b 100644
--- a/src/include/nodes/primnodes.h
+++ b/src/include/nodes/primnodes.h
@@ -304,6 +304,9 @@ typedef struct Param
  * replaced with a single argument representing the partial-aggregate
  * transition values.
  *
+ * aggpresorted is set by the query planner for ORDER BY aggregates where the
+ * query plan chosen provides presorted input for the executor.
+ *
  * aggsplit indicates the expected partial-aggregation mode for the Aggref's
  * parent plan node.  It's always set to AGGSPLIT_SIMPLE in the parser, but
  * the planner might change it to something else.  We use this mainly as
@@ -335,6 +338,7 @@ typedef struct Aggref
 	bool		aggvariadic;	/* true if variadic arguments have been
 								 * combined into an array last argument */
 	char		aggkind;		/* aggregate kind (see pg_aggregate.h) */
+	bool		aggpresorted;	/* Agg input already sorted */
 	Index		agglevelsup;	/* > 0 if agg belongs to outer query */
 	AggSplit	aggsplit;		/* expected agg-splitting mode of parent Agg */
 	int			aggno;			/* unique ID within the Agg node */
diff --git a/src/test/regress/expected/partition_aggregate.out b/src/test/regress/expected/partition_aggregate.out
index dfa4b036b5..484c94e585 100644
--- a/src/test/regress/expected/partition_aggregate.out
+++ b/src/test/regress/expected/partition_aggregate.out
@@ -367,17 +367,17 @@ SELECT c, sum(b order by a) FROM pagg_tab GROUP BY c ORDER BY 1, 2;
          ->  GroupAggregate
                Group Key: pagg_tab.c
                ->  Sort
-                     Sort Key: pagg_tab.c
+                     Sort Key: pagg_tab.c, pagg_tab.a
                      ->  Seq Scan on pagg_tab_p1 pagg_tab
          ->  GroupAggregate
                Group Key: pagg_tab_1.c
                ->  Sort
-                     Sort Key: pagg_tab_1.c
+                     Sort Key: pagg_tab_1.c, pagg_tab_1.a
                      ->  Seq Scan on pagg_tab_p2 pagg_tab_1
          ->  GroupAggregate
                Group Key: pagg_tab_2.c
                ->  Sort
-                     Sort Key: pagg_tab_2.c
+                     Sort Key: pagg_tab_2.c, pagg_tab_2.a
                      ->  Seq Scan on pagg_tab_p3 pagg_tab_2
 (18 rows)
 
@@ -393,7 +393,7 @@ SELECT a, sum(b order by a) FROM pagg_tab GROUP BY a ORDER BY 1, 2;
    ->  GroupAggregate
          Group Key: pagg_tab.a
          ->  Sort
-               Sort Key: pagg_tab.a
+               Sort Key: pagg_tab.a, pagg_tab.a
                ->  Append
                      ->  Seq Scan on pagg_tab_p1 pagg_tab_1
                      ->  Seq Scan on pagg_tab_p2 pagg_tab_2
-- 
2.30.2

