From 378107f06b562460e4cfefdbc7f1b3c58ad51ab2 Mon Sep 17 00:00:00 2001
From: David Rowley <dgrowley@gmail.com>
Date: Mon, 12 Jul 2021 20:24:08 +1200
Subject: [PATCH v3 2/2] WIP: Add planner support for DISTINCT aggregates

---
 src/backend/executor/execExpr.c               | 24 +++++-
 src/backend/executor/execExprInterp.c         | 82 +++++++++++++++++++
 src/backend/executor/nodeAgg.c                | 21 ++++-
 src/backend/optimizer/plan/planner.c          | 10 +--
 src/include/executor/execExpr.h               | 13 +++
 src/include/executor/nodeAgg.h                |  6 +-
 src/include/nodes/primnodes.h                 |  4 +-
 src/test/regress/expected/aggregates.out      |  2 +-
 .../regress/expected/partition_aggregate.out  | 12 +--
 src/test/regress/expected/tuplesort.out       | 40 +++++----
 10 files changed, 176 insertions(+), 38 deletions(-)

diff --git a/src/backend/executor/execExpr.c b/src/backend/executor/execExpr.c
index a6e9c48f11..0c84e3757a 100644
--- a/src/backend/executor/execExpr.c
+++ b/src/backend/executor/execExpr.c
@@ -3426,7 +3426,8 @@ ExecBuildAggTrans(AggState *aggstate, AggStatePerPhase phase,
 
 			/*
 			 * Normal transition function without ORDER BY / DISTINCT or with
-			 * ORDER BY but the planner has given us pre-sorted input.
+			 * ORDER BY / DISTINCT but the planner has given us pre-sorted
+			 * input.
 			 */
 			strictargs = trans_fcinfo->args + 1;
 
@@ -3514,6 +3515,21 @@ ExecBuildAggTrans(AggState *aggstate, AggStatePerPhase phase,
 										 state->steps_len - 1);
 		}
 
+		/* Handle DISTINCT aggregates which have pre-sorted input */
+		if (pertrans->numDistinctCols > 0 && !pertrans->aggsortrequired)
+		{
+			if (pertrans->numDistinctCols > 1)
+				scratch.opcode = EEOP_AGG_PRESORTED_DISTINCT_MULTI;
+			else
+				scratch.opcode = EEOP_AGG_PRESORTED_DISTINCT_SINGLE;
+
+			scratch.d.agg_presorted_distinctcheck.pertrans = pertrans;
+			scratch.d.agg_presorted_distinctcheck.jumpdistinct = -1;	/* adjust later */
+			ExprEvalPushStep(state, &scratch);
+			adjust_bailout = lappend_int(adjust_bailout,
+										 state->steps_len - 1);
+		}
+
 		/*
 		 * Call transition function (once for each concurrently evaluated
 		 * grouping set). Do so for both sort and hash based computations, as
@@ -3574,6 +3590,12 @@ ExecBuildAggTrans(AggState *aggstate, AggStatePerPhase phase,
 				Assert(as->d.agg_deserialize.jumpnull == -1);
 				as->d.agg_deserialize.jumpnull = state->steps_len;
 			}
+			else if (as->opcode == EEOP_AGG_PRESORTED_DISTINCT_SINGLE ||
+					 as->opcode == EEOP_AGG_PRESORTED_DISTINCT_MULTI)
+			{
+				Assert(as->d.agg_presorted_distinctcheck.jumpdistinct == -1);
+				as->d.agg_presorted_distinctcheck.jumpdistinct = state->steps_len;
+			}
 			else
 				Assert(false);
 		}
diff --git a/src/backend/executor/execExprInterp.c b/src/backend/executor/execExprInterp.c
index eb49817cee..17904bdb7f 100644
--- a/src/backend/executor/execExprInterp.c
+++ b/src/backend/executor/execExprInterp.c
@@ -488,6 +488,8 @@ ExecInterpExpr(ExprState *state, ExprContext *econtext, bool *isnull)
 		&&CASE_EEOP_AGG_PLAIN_TRANS_INIT_STRICT_BYREF,
 		&&CASE_EEOP_AGG_PLAIN_TRANS_STRICT_BYREF,
 		&&CASE_EEOP_AGG_PLAIN_TRANS_BYREF,
+		&&CASE_EEOP_AGG_PRESORTED_DISTINCT_SINGLE,
+		&&CASE_EEOP_AGG_PRESORTED_DISTINCT_MULTI,
 		&&CASE_EEOP_AGG_ORDERED_TRANS_DATUM,
 		&&CASE_EEOP_AGG_ORDERED_TRANS_TUPLE,
 		&&CASE_EEOP_LAST
@@ -1772,6 +1774,86 @@ ExecInterpExpr(ExprState *state, ExprContext *econtext, bool *isnull)
 			EEO_NEXT();
 		}
 
+		EEO_CASE(EEOP_AGG_PRESORTED_DISTINCT_SINGLE)
+		{
+			AggStatePerTrans pertrans = op->d.agg_presorted_distinctcheck.pertrans;
+			Datum		value = pertrans->transfn_fcinfo->args[1].value;
+			bool		isnull = pertrans->transfn_fcinfo->args[1].isnull;
+
+			if (!pertrans->haslast ||
+				pertrans->lastisnull != isnull ||
+				!DatumGetBool(FunctionCall2Coll(&pertrans->equalfnOne,
+												pertrans->aggCollation,
+												pertrans->lastdatum, value)))
+			{
+				if (pertrans->haslast && !pertrans->inputtypeByVal)
+					pfree(DatumGetPointer(pertrans->lastdatum));
+
+				pertrans->haslast = true;
+				if (!isnull)
+				{
+					AggState   *aggstate = castNode(AggState, state->parent);
+
+					/*
+					 * XXX is it worth having a dedicated ByVal version of this
+					 * operation so that we can skip switching memory contexts
+					 * and do a simple assign rather than datumCopy below?
+					 */
+					MemoryContext oldContext;
+
+					oldContext = MemoryContextSwitchTo(aggstate->curaggcontext->ecxt_per_tuple_memory);
+
+					pertrans->lastdatum = datumCopy(value, pertrans->inputtypeByVal, pertrans->inputtypeLen);
+
+					MemoryContextSwitchTo(oldContext);
+				}
+				else
+					pertrans->lastdatum = (Datum) 0;
+				pertrans->lastisnull = isnull;
+				EEO_NEXT();
+			}
+			EEO_JUMP(op->d.agg_presorted_distinctcheck.jumpdistinct);
+		}
+
+		EEO_CASE(EEOP_AGG_PRESORTED_DISTINCT_MULTI)
+		{
+			AggState   *aggstate = castNode(AggState, state->parent);
+			AggStatePerTrans pertrans = op->d.agg_presorted_distinctcheck.pertrans;
+			ExprContext *tmpcontext = aggstate->tmpcontext;
+			int			i;
+
+			/*
+			 * XXX or should we have had these values copied directly into the
+			 * sortslot?  If we did then we'd still need to copy them into the
+			 * transfn_fcinfo->args here if we detect the tuple is distinct
+			 * from the previous tuple.
+			 */
+			for (i = 0; i < pertrans->numTransInputs; i++)
+			{
+				pertrans->sortslot->tts_values[i] = pertrans->transfn_fcinfo->args[i + 1].value;
+				pertrans->sortslot->tts_isnull[i] = pertrans->transfn_fcinfo->args[i + 1].isnull;
+			}
+
+			ExecClearTuple(pertrans->sortslot);
+			pertrans->sortslot->tts_nvalid = pertrans->numInputs;
+			ExecStoreVirtualTuple(pertrans->sortslot);
+
+			tmpcontext->ecxt_outertuple = pertrans->sortslot;
+			tmpcontext->ecxt_innertuple = pertrans->uniqslot;
+
+			if (!pertrans->haslast ||
+				!ExecQual(pertrans->equalfnMulti, tmpcontext))
+			{
+				if (pertrans->haslast)
+					ExecClearTuple(pertrans->uniqslot);
+
+				pertrans->haslast = true;
+				ExecCopySlot(pertrans->uniqslot, pertrans->sortslot);
+				EEO_NEXT();
+			}
+			EEO_JUMP(op->d.agg_presorted_distinctcheck.jumpdistinct);
+		}
+
 		/* process single-column ordered aggregate datum */
 		EEO_CASE(EEOP_AGG_ORDERED_TRANS_DATUM)
 		{
diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index e28d53c17b..80689ea466 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -1342,6 +1342,21 @@ finalize_aggregates(AggState *aggstate,
 												pertrans,
 												pergroupstate);
 		}
+		else if (pertrans->numDistinctCols > 0 && pertrans->haslast)
+		{
+			pertrans->haslast = false;
+
+			if (pertrans->numDistinctCols == 1)
+			{
+				if (!pertrans->inputtypeByVal && !pertrans->lastisnull)
+					pfree(DatumGetPointer(pertrans->lastdatum));
+
+				pertrans->lastisnull = false;
+				pertrans->lastdatum = (Datum) 0;
+			}
+			else
+				ExecClearTuple(pertrans->uniqslot);
+		}
 	}
 
 	/*
@@ -4234,10 +4249,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
 		numSortCols = numDistinctCols = 0;
 		pertrans->aggsortrequired = false;
 	}
-	else if (aggref->aggpresorted)
+	else if (aggref->aggpresorted && aggref->aggdistinct == NIL)
 	{
-		/* DISTINCT not yet supported for aggpresorted */
-		Assert(aggref->aggdistinct == NIL);
 		sortlist = NIL;
 		numSortCols = numDistinctCols = 0;
 		pertrans->aggsortrequired = false;
@@ -4247,7 +4260,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
 		sortlist = aggref->aggdistinct;
 		numSortCols = numDistinctCols = list_length(sortlist);
 		Assert(numSortCols >= list_length(aggref->aggorder));
-		pertrans->aggsortrequired = true;
+		pertrans->aggsortrequired = !aggref->aggpresorted;
 	}
 	else
 	{
diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c
index d5b184ab52..a6c0a639f9 100644
--- a/src/backend/optimizer/plan/planner.c
+++ b/src/backend/optimizer/plan/planner.c
@@ -3081,7 +3081,7 @@ standard_qp_callback(PlannerInfo *root, void *extra)
 	else
 		root->group_pathkeys = NIL;
 
-	/* Determine pathkeys for aggregate functions with an ORDER BY */
+	/* Determine pathkeys for aggregate functions with DISTINCT/ORDER BY */
 	if (parse->groupingSets == NIL && root->numOrderedAggs > 0 &&
 		(qp_extra->groupClause == NIL || root->group_pathkeys))
 	{
@@ -3097,15 +3097,15 @@ standard_qp_callback(PlannerInfo *root, void *extra)
 			if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
 				continue;
 
-			/* DISTINCT aggregates not yet supported by the planner */
 			if (aggref->aggdistinct != NIL)
-				continue;
-
-			if (aggref->aggorder != NIL)
+				sortlist = aggref->aggdistinct;
+			else if (aggref->aggorder != NIL)
 				sortlist = aggref->aggorder;
 			else
 				continue;
 
+			Assert(sortlist != NIL);
+
 			/*
 			 * Find the pathkeys with the most sorted derivative of the first
 			 * Aggref. For example, if we determine the pathkeys for the first
diff --git a/src/include/executor/execExpr.h b/src/include/executor/execExpr.h
index 6a24341faa..633eb809ac 100644
--- a/src/include/executor/execExpr.h
+++ b/src/include/executor/execExpr.h
@@ -252,6 +252,8 @@ typedef enum ExprEvalOp
 	EEOP_AGG_PLAIN_TRANS_INIT_STRICT_BYREF,
 	EEOP_AGG_PLAIN_TRANS_STRICT_BYREF,
 	EEOP_AGG_PLAIN_TRANS_BYREF,
+	EEOP_AGG_PRESORTED_DISTINCT_SINGLE,
+	EEOP_AGG_PRESORTED_DISTINCT_MULTI,
 	EEOP_AGG_ORDERED_TRANS_DATUM,
 	EEOP_AGG_ORDERED_TRANS_TUPLE,
 
@@ -658,6 +660,17 @@ typedef struct ExprEvalStep
 			int			jumpnull;
 		}			agg_plain_pergroup_nullcheck;
 
+		/* for EEOP_AGG_PRESORTED_DISTINCT_{SINGLE,MULTI} */
+		struct
+		{
+			AggStatePerTrans pertrans;
+			ExprContext *aggcontext;
+			int			setno;
+			int			transno;
+			int			setoff;
+			int			jumpdistinct;
+		}			agg_presorted_distinctcheck;
+
 		/* for EEOP_AGG_PLAIN_TRANS_[INIT_][STRICT_]{BYVAL,BYREF} */
 		/* for EEOP_AGG_ORDERED_TRANS_{DATUM,TUPLE} */
 		struct
diff --git a/src/include/executor/nodeAgg.h b/src/include/executor/nodeAgg.h
index bcd0643699..22976655e1 100644
--- a/src/include/executor/nodeAgg.h
+++ b/src/include/executor/nodeAgg.h
@@ -49,7 +49,8 @@ typedef struct AggStatePerTransData
 	bool		aggshared;
 
 	/*
-	 * True for ORDER BY aggregates that are not Aggref->aggpresorted
+	 * True for ORDER BY / DISTINCT aggregates that are not
+	 * Aggref->aggpresorted
 	 */
 	bool		aggsortrequired;
 
@@ -141,6 +142,9 @@ typedef struct AggStatePerTransData
 	TupleTableSlot *sortslot;	/* current input tuple */
 	TupleTableSlot *uniqslot;	/* used for multi-column DISTINCT */
 	TupleDesc	sortdesc;		/* descriptor of input tuples */
+	Datum		lastdatum;		/* used for single-column DISTINCT */
+	bool		lastisnull;		/* used for single-column DISTINCT */
+	bool		haslast;		/* got a last value for DISTINCT check */
 
 	/*
 	 * These values are working state that is initialized at the start of an
diff --git a/src/include/nodes/primnodes.h b/src/include/nodes/primnodes.h
index 2f3cc39d4b..ad8e6cc8d9 100644
--- a/src/include/nodes/primnodes.h
+++ b/src/include/nodes/primnodes.h
@@ -304,8 +304,8 @@ typedef struct Param
  * replaced with a single argument representing the partial-aggregate
  * transition values.
  *
- * aggpresorted is set by the query planner for ORDER BY aggregates where the
- * query plan chosen provides presorted input for the executor.
+ * aggpresorted is set by the query planner for ORDER BY / DISTINCT aggregates
+ * where the query plan chosen provides presorted input for the executor.
  *
  * aggsplit indicates the expected partial-aggregation mode for the Aggref's
  * parent plan node.  It's always set to AGGSPLIT_SIMPLE in the parser, but
diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out
index ca06d41dd0..db45ba0aba 100644
--- a/src/test/regress/expected/aggregates.out
+++ b/src/test/regress/expected/aggregates.out
@@ -2224,8 +2224,8 @@ NOTICE:  avg_transfn called with 3
 -- shouldn't share states due to the distinctness not matching.
 select my_avg(distinct one),my_sum(one) from (values(1),(3)) t(one);
 NOTICE:  avg_transfn called with 1
-NOTICE:  avg_transfn called with 3
 NOTICE:  avg_transfn called with 1
+NOTICE:  avg_transfn called with 3
 NOTICE:  avg_transfn called with 3
  my_avg | my_sum 
 --------+--------
diff --git a/src/test/regress/expected/partition_aggregate.out b/src/test/regress/expected/partition_aggregate.out
index 484c94e585..72c240c9f7 100644
--- a/src/test/regress/expected/partition_aggregate.out
+++ b/src/test/regress/expected/partition_aggregate.out
@@ -959,13 +959,13 @@ SELECT a, sum(b), array_agg(distinct c), count(*) FROM pagg_tab_ml GROUP BY a HA
                      Group Key: pagg_tab_ml.a
                      Filter: (avg(pagg_tab_ml.b) < '3'::numeric)
                      ->  Sort
-                           Sort Key: pagg_tab_ml.a
+                           Sort Key: pagg_tab_ml.a, pagg_tab_ml.c
                            ->  Seq Scan on pagg_tab_ml_p1 pagg_tab_ml
                ->  GroupAggregate
                      Group Key: pagg_tab_ml_5.a
                      Filter: (avg(pagg_tab_ml_5.b) < '3'::numeric)
                      ->  Sort
-                           Sort Key: pagg_tab_ml_5.a
+                           Sort Key: pagg_tab_ml_5.a, pagg_tab_ml_5.c
                            ->  Append
                                  ->  Seq Scan on pagg_tab_ml_p3_s1 pagg_tab_ml_5
                                  ->  Seq Scan on pagg_tab_ml_p3_s2 pagg_tab_ml_6
@@ -973,7 +973,7 @@ SELECT a, sum(b), array_agg(distinct c), count(*) FROM pagg_tab_ml GROUP BY a HA
                      Group Key: pagg_tab_ml_2.a
                      Filter: (avg(pagg_tab_ml_2.b) < '3'::numeric)
                      ->  Sort
-                           Sort Key: pagg_tab_ml_2.a
+                           Sort Key: pagg_tab_ml_2.a, pagg_tab_ml_2.c
                            ->  Append
                                  ->  Seq Scan on pagg_tab_ml_p2_s1 pagg_tab_ml_2
                                  ->  Seq Scan on pagg_tab_ml_p2_s2 pagg_tab_ml_3
@@ -1005,13 +1005,13 @@ SELECT a, sum(b), array_agg(distinct c), count(*) FROM pagg_tab_ml GROUP BY a HA
                Group Key: pagg_tab_ml.a
                Filter: (avg(pagg_tab_ml.b) < '3'::numeric)
                ->  Sort
-                     Sort Key: pagg_tab_ml.a
+                     Sort Key: pagg_tab_ml.a, pagg_tab_ml.c
                      ->  Seq Scan on pagg_tab_ml_p1 pagg_tab_ml
          ->  GroupAggregate
                Group Key: pagg_tab_ml_5.a
                Filter: (avg(pagg_tab_ml_5.b) < '3'::numeric)
                ->  Sort
-                     Sort Key: pagg_tab_ml_5.a
+                     Sort Key: pagg_tab_ml_5.a, pagg_tab_ml_5.c
                      ->  Append
                            ->  Seq Scan on pagg_tab_ml_p3_s1 pagg_tab_ml_5
                            ->  Seq Scan on pagg_tab_ml_p3_s2 pagg_tab_ml_6
@@ -1019,7 +1019,7 @@ SELECT a, sum(b), array_agg(distinct c), count(*) FROM pagg_tab_ml GROUP BY a HA
                Group Key: pagg_tab_ml_2.a
                Filter: (avg(pagg_tab_ml_2.b) < '3'::numeric)
                ->  Sort
-                     Sort Key: pagg_tab_ml_2.a
+                     Sort Key: pagg_tab_ml_2.a, pagg_tab_ml_2.c
                      ->  Append
                            ->  Seq Scan on pagg_tab_ml_p2_s1 pagg_tab_ml_2
                            ->  Seq Scan on pagg_tab_ml_p2_s2 pagg_tab_ml_3
diff --git a/src/test/regress/expected/tuplesort.out b/src/test/regress/expected/tuplesort.out
index 418f296a3f..ef79574ecf 100644
--- a/src/test/regress/expected/tuplesort.out
+++ b/src/test/regress/expected/tuplesort.out
@@ -622,15 +622,17 @@ EXPLAIN (COSTS OFF) :qry;
          ->  GroupAggregate
                Group Key: a.col12
                Filter: (count(*) > 1)
-               ->  Merge Join
-                     Merge Cond: (a.col12 = b.col12)
-                     ->  Sort
-                           Sort Key: a.col12 DESC
-                           ->  Seq Scan on test_mark_restore a
-                     ->  Sort
-                           Sort Key: b.col12 DESC
-                           ->  Seq Scan on test_mark_restore b
-(14 rows)
+               ->  Sort
+                     Sort Key: a.col12 DESC, a.col1
+                     ->  Merge Join
+                           Merge Cond: (a.col12 = b.col12)
+                           ->  Sort
+                                 Sort Key: a.col12
+                                 ->  Seq Scan on test_mark_restore a
+                           ->  Sort
+                                 Sort Key: b.col12
+                                 ->  Seq Scan on test_mark_restore b
+(16 rows)
 
 :qry;
  col12 | count | count | count | count | count 
@@ -658,15 +660,17 @@ EXPLAIN (COSTS OFF) :qry;
          ->  GroupAggregate
                Group Key: a.col12
                Filter: (count(*) > 1)
-               ->  Merge Join
-                     Merge Cond: (a.col12 = b.col12)
-                     ->  Sort
-                           Sort Key: a.col12 DESC
-                           ->  Seq Scan on test_mark_restore a
-                     ->  Sort
-                           Sort Key: b.col12 DESC
-                           ->  Seq Scan on test_mark_restore b
-(14 rows)
+               ->  Sort
+                     Sort Key: a.col12 DESC, a.col1
+                     ->  Merge Join
+                           Merge Cond: (a.col12 = b.col12)
+                           ->  Sort
+                                 Sort Key: a.col12
+                                 ->  Seq Scan on test_mark_restore a
+                           ->  Sort
+                                 Sort Key: b.col12
+                                 ->  Seq Scan on test_mark_restore b
+(16 rows)
 
 :qry;
  col12 | count | count | count | count | count 
-- 
2.30.2

