From 4a3b4f080b113ed68c7272ba1ec776ec00516a61 Mon Sep 17 00:00:00 2001
From: amit <amitlangote09@gmail.com>
Date: Thu, 18 Jul 2019 10:33:20 +0900
Subject: [PATCH v6] Fix partitionwise join to handle FULL JOINs correctly

---
 src/backend/nodes/makefuncs.c                |  17 ++++
 src/backend/optimizer/util/relnode.c         |  43 +++++++--
 src/backend/parser/parse_clause.c            |  20 ++---
 src/include/nodes/makefuncs.h                |   2 +
 src/test/regress/expected/partition_join.out | 129 +++++++++++++++++++++++++++
 src/test/regress/sql/partition_join.sql      |  23 +++++
 6 files changed, 214 insertions(+), 20 deletions(-)

diff --git a/src/backend/nodes/makefuncs.c b/src/backend/nodes/makefuncs.c
index b442b5a..3e58fa9 100644
--- a/src/backend/nodes/makefuncs.c
+++ b/src/backend/nodes/makefuncs.c
@@ -812,3 +812,20 @@ makeVacuumRelation(RangeVar *relation, Oid oid, List *va_cols)
 	v->va_cols = va_cols;
 	return v;
 }
+
+/*
+ * makeCoalesceExpr
+ */
+CoalesceExpr *
+makeCoalesceExpr(Oid typid, Oid collid, Node *l_node, Node *r_node,
+				 int location)
+{
+	CoalesceExpr *c = makeNode(CoalesceExpr);
+
+	c->coalescetype = typid;
+	c->coalescecollid = collid;
+	c->args = list_make2(l_node, r_node);
+	c->location = location;
+
+	return c;
+}
diff --git a/src/backend/optimizer/util/relnode.c b/src/backend/optimizer/util/relnode.c
index af1fb48..a211028 100644
--- a/src/backend/optimizer/util/relnode.c
+++ b/src/backend/optimizer/util/relnode.c
@@ -17,6 +17,7 @@
 #include <limits.h>
 
 #include "miscadmin.h"
+#include "nodes/makefuncs.h"
 #include "optimizer/appendinfo.h"
 #include "optimizer/clauses.h"
 #include "optimizer/cost.h"
@@ -1890,7 +1891,8 @@ set_joinrel_partition_key_exprs(RelOptInfo *joinrel,
 								RelOptInfo *outer_rel, RelOptInfo *inner_rel,
 								JoinType jointype)
 {
-	int			partnatts = joinrel->part_scheme->partnatts;
+	PartitionScheme part_scheme = joinrel->part_scheme;
+	int			partnatts = part_scheme->partnatts;
 
 	joinrel->partexprs = (List **) palloc0(sizeof(List *) * partnatts);
 	joinrel->nullable_partexprs =
@@ -1963,12 +1965,39 @@ set_joinrel_partition_key_exprs(RelOptInfo *joinrel,
 				 * that involve strict join operators.
 				 */
 			case JOIN_FULL:
-				nullable_partexpr = list_concat_copy(outer_expr,
-													 inner_expr);
-				nullable_partexpr = list_concat(nullable_partexpr,
-												outer_null_expr);
-				nullable_partexpr = list_concat(nullable_partexpr,
-												inner_null_expr);
+				{
+					Oid		coltype = part_scheme->partopcintype[cnt],
+							colcoll = part_scheme->partcollation[cnt];
+					Node   *larg;
+					ListCell *rarg;
+					CoalesceExpr *coalesce = NULL;
+
+					nullable_partexpr = list_concat_copy(outer_expr,
+														 inner_expr);
+					nullable_partexpr = list_concat(nullable_partexpr,
+													outer_null_expr);
+					nullable_partexpr = list_concat(nullable_partexpr,
+													inner_null_expr);
+
+					/*
+					 * Add a CoalesceExpr wrapping all of the collected
+					 * nullable expressions, because clauses for a full join
+					 * written with USING () would not be matched to the
+					 * joinrel's partition keys without this.
+					 */
+					larg = (Node *) linitial(nullable_partexpr);
+					rarg = list_second_cell(nullable_partexpr);
+					while (rarg)
+					{
+						coalesce = makeCoalesceExpr(coltype, colcoll, larg,
+													(Node *) lfirst(rarg), -1);
+						larg = (Node *) coalesce;
+						rarg = lnext(nullable_partexpr, rarg);
+					}
+					Assert(coalesce != NULL);
+					nullable_partexpr = list_concat(nullable_partexpr,
+													list_make1(coalesce));
+				}
 				break;
 
 			default:
diff --git a/src/backend/parser/parse_clause.c b/src/backend/parser/parse_clause.c
index 36a3eff..bb9890d 100644
--- a/src/backend/parser/parse_clause.c
+++ b/src/backend/parser/parse_clause.c
@@ -1643,20 +1643,14 @@ buildMergedJoinVar(ParseState *pstate, JoinType jointype,
 			res_node = r_node;
 			break;
 		case JOIN_FULL:
-			{
-				/*
-				 * Here we must build a COALESCE expression to ensure that the
-				 * join output is non-null if either input is.
-				 */
-				CoalesceExpr *c = makeNode(CoalesceExpr);
+			/*
+			 * Here we must build a COALESCE expression to ensure that the
+			 * join output is non-null if either input is.
+			 */
+			res_node = (Node *) makeCoalesceExpr(outcoltype, InvalidOid,
+												 l_node, r_node, -1);
+			break;
 
-				c->coalescetype = outcoltype;
-				/* coalescecollid will get set below */
-				c->args = list_make2(l_node, r_node);
-				c->location = -1;
-				res_node = (Node *) c;
-				break;
-			}
 		default:
 			elog(ERROR, "unrecognized join type: %d", (int) jointype);
 			res_node = NULL;	/* keep compiler quiet */
diff --git a/src/include/nodes/makefuncs.h b/src/include/nodes/makefuncs.h
index 31d9aed..bb26ef5 100644
--- a/src/include/nodes/makefuncs.h
+++ b/src/include/nodes/makefuncs.h
@@ -104,5 +104,7 @@ extern DefElem *makeDefElemExtended(char *nameSpace, char *name, Node *arg,
 extern GroupingSet *makeGroupingSet(GroupingSetKind kind, List *content, int location);
 
 extern VacuumRelation *makeVacuumRelation(RangeVar *relation, Oid oid, List *va_cols);
+extern CoalesceExpr *makeCoalesceExpr(Oid typid, Oid collid, Node *l_node, Node *r_node,
+				 int location);
 
 #endif							/* MAKEFUNC_H */
diff --git a/src/test/regress/expected/partition_join.out b/src/test/regress/expected/partition_join.out
index b3fbe47..e928767 100644
--- a/src/test/regress/expected/partition_join.out
+++ b/src/test/regress/expected/partition_join.out
@@ -750,6 +750,135 @@ SELECT t1.a, t1.c, t2.b, t2.c, t3.a + t3.b, t3.c FROM (prt1 t1 LEFT JOIN prt2 t2
  550 | 0550 |     |      |     1100 | 0
 (12 rows)
 
+-- 3-way FULL JOIN
+SET enable_partitionwise_aggregate TO true;
+EXPLAIN (COSTS OFF)
+SELECT a, b FROM prt1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) USING (a, b)
+  WHERE a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+                                                                     QUERY PLAN                                                                      
+-----------------------------------------------------------------------------------------------------------------------------------------------------
+ Group
+   Group Key: (COALESCE(COALESCE(prt1.a, p2.a), p3.a)), (COALESCE(COALESCE(prt1.b, p2.b), p3.b))
+   ->  Merge Append
+         Sort Key: (COALESCE(COALESCE(prt1.a, p2.a), p3.a)), (COALESCE(COALESCE(prt1.b, p2.b), p3.b))
+         ->  Group
+               Group Key: (COALESCE(COALESCE(prt1.a, p2.a), p3.a)), (COALESCE(COALESCE(prt1.b, p2.b), p3.b))
+               ->  Sort
+                     Sort Key: (COALESCE(COALESCE(prt1.a, p2.a), p3.a)), (COALESCE(COALESCE(prt1.b, p2.b), p3.b))
+                     ->  Hash Full Join
+                           Hash Cond: ((COALESCE(prt1.a, p2.a) = p3.a) AND (COALESCE(prt1.b, p2.b) = p3.b))
+                           Filter: ((COALESCE(COALESCE(prt1.a, p2.a), p3.a) >= 490) AND (COALESCE(COALESCE(prt1.a, p2.a), p3.a) <= 510))
+                           ->  Hash Full Join
+                                 Hash Cond: ((prt1.a = p2.a) AND (prt1.b = p2.b))
+                                 ->  Seq Scan on prt1_p1 prt1
+                                 ->  Hash
+                                       ->  Seq Scan on prt2_p1 p2
+                           ->  Hash
+                                 ->  Seq Scan on prt2_p1 p3
+         ->  Group
+               Group Key: (COALESCE(COALESCE(prt1_1.a, p2_1.a), p3_1.a)), (COALESCE(COALESCE(prt1_1.b, p2_1.b), p3_1.b))
+               ->  Sort
+                     Sort Key: (COALESCE(COALESCE(prt1_1.a, p2_1.a), p3_1.a)), (COALESCE(COALESCE(prt1_1.b, p2_1.b), p3_1.b))
+                     ->  Hash Full Join
+                           Hash Cond: ((COALESCE(prt1_1.a, p2_1.a) = p3_1.a) AND (COALESCE(prt1_1.b, p2_1.b) = p3_1.b))
+                           Filter: ((COALESCE(COALESCE(prt1_1.a, p2_1.a), p3_1.a) >= 490) AND (COALESCE(COALESCE(prt1_1.a, p2_1.a), p3_1.a) <= 510))
+                           ->  Hash Full Join
+                                 Hash Cond: ((prt1_1.a = p2_1.a) AND (prt1_1.b = p2_1.b))
+                                 ->  Seq Scan on prt1_p2 prt1_1
+                                 ->  Hash
+                                       ->  Seq Scan on prt2_p2 p2_1
+                           ->  Hash
+                                 ->  Seq Scan on prt2_p2 p3_1
+         ->  Group
+               Group Key: (COALESCE(COALESCE(prt1_2.a, p2_2.a), p3_2.a)), (COALESCE(COALESCE(prt1_2.b, p2_2.b), p3_2.b))
+               ->  Sort
+                     Sort Key: (COALESCE(COALESCE(prt1_2.a, p2_2.a), p3_2.a)), (COALESCE(COALESCE(prt1_2.b, p2_2.b), p3_2.b))
+                     ->  Hash Full Join
+                           Hash Cond: ((COALESCE(prt1_2.a, p2_2.a) = p3_2.a) AND (COALESCE(prt1_2.b, p2_2.b) = p3_2.b))
+                           Filter: ((COALESCE(COALESCE(prt1_2.a, p2_2.a), p3_2.a) >= 490) AND (COALESCE(COALESCE(prt1_2.a, p2_2.a), p3_2.a) <= 510))
+                           ->  Hash Full Join
+                                 Hash Cond: ((prt1_2.a = p2_2.a) AND (prt1_2.b = p2_2.b))
+                                 ->  Seq Scan on prt1_p3 prt1_2
+                                 ->  Hash
+                                       ->  Seq Scan on prt2_p3 p2_2
+                           ->  Hash
+                                 ->  Seq Scan on prt2_p3 p3_2
+(46 rows)
+
+SELECT a, b FROM prt1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) USING (a, b)
+  WHERE a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+  a  | b  
+-----+----
+ 490 | 15
+ 492 | 17
+ 494 | 19
+ 495 | 20
+ 496 | 21
+ 498 | 23
+ 500 |  0
+ 501 |  1
+ 502 |  2
+ 504 |  4
+ 506 |  6
+ 507 |  7
+ 508 |  8
+ 510 | 10
+(14 rows)
+
+EXPLAIN (COSTS OFF)
+SELECT p1.a, p1.b FROM prt1 p1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) ON (p2.a = p3.a AND p2.b = p3.b)
+  WHERE p1.a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+                                   QUERY PLAN                                   
+--------------------------------------------------------------------------------
+ Group
+   Group Key: p1.a, p1.b
+   ->  Sort
+         Sort Key: p1.a, p1.b
+         ->  Append
+               ->  Nested Loop Left Join
+                     ->  Hash Right Join
+                           Hash Cond: ((p2_1.a = p1_1.a) AND (p2_1.b = p1_1.b))
+                           ->  Seq Scan on prt2_p2 p2_1
+                           ->  Hash
+                                 ->  Seq Scan on prt1_p2 p1_1
+                                       Filter: ((a >= 490) AND (a <= 510))
+                     ->  Index Scan using iprt2_p2_b on prt2_p2 p3_1
+                           Index Cond: (a = p2_1.a)
+                           Filter: (p2_1.b = b)
+               ->  Nested Loop Left Join
+                     ->  Hash Right Join
+                           Hash Cond: ((p2_2.a = p1_2.a) AND (p2_2.b = p1_2.b))
+                           ->  Seq Scan on prt2_p3 p2_2
+                           ->  Hash
+                                 ->  Seq Scan on prt1_p3 p1_2
+                                       Filter: ((a >= 490) AND (a <= 510))
+                     ->  Index Scan using iprt2_p3_b on prt2_p3 p3_2
+                           Index Cond: (a = p2_2.a)
+                           Filter: (p2_2.b = b)
+(25 rows)
+
+SELECT p1.a, p1.b FROM prt1 p1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) ON (p2.a = p3.a AND p2.b = p3.b)
+  WHERE p1.a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+  a  | b  
+-----+----
+ 490 | 15
+ 492 | 17
+ 494 | 19
+ 496 | 21
+ 498 | 23
+ 500 |  0
+ 502 |  2
+ 504 |  4
+ 506 |  6
+ 508 |  8
+ 510 | 10
+(11 rows)
+
+RESET enable_partitionwise_aggregate;
 -- Cases with non-nullable expressions in subquery results;
 -- make sure these go to null as expected
 EXPLAIN (COSTS OFF)
diff --git a/src/test/regress/sql/partition_join.sql b/src/test/regress/sql/partition_join.sql
index 575ba7b..d28abf7 100644
--- a/src/test/regress/sql/partition_join.sql
+++ b/src/test/regress/sql/partition_join.sql
@@ -145,6 +145,28 @@ EXPLAIN (COSTS OFF)
 SELECT t1.a, t1.c, t2.b, t2.c, t3.a + t3.b, t3.c FROM (prt1 t1 LEFT JOIN prt2 t2 ON t1.a = t2.b) RIGHT JOIN prt1_e t3 ON (t1.a = (t3.a + t3.b)/2) WHERE t3.c = 0 ORDER BY t1.a, t2.b, t3.a + t3.b;
 SELECT t1.a, t1.c, t2.b, t2.c, t3.a + t3.b, t3.c FROM (prt1 t1 LEFT JOIN prt2 t2 ON t1.a = t2.b) RIGHT JOIN prt1_e t3 ON (t1.a = (t3.a + t3.b)/2) WHERE t3.c = 0 ORDER BY t1.a, t2.b, t3.a + t3.b;
 
+-- 3-way FULL JOIN
+
+SET enable_partitionwise_aggregate TO true;
+
+EXPLAIN (COSTS OFF)
+SELECT a, b FROM prt1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) USING (a, b)
+  WHERE a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+SELECT a, b FROM prt1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) USING (a, b)
+  WHERE a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+
+EXPLAIN (COSTS OFF)
+SELECT p1.a, p1.b FROM prt1 p1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) ON (p2.a = p3.a AND p2.b = p3.b)
+  WHERE p1.a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+SELECT p1.a, p1.b FROM prt1 p1 FULL JOIN prt2 p2(b,a,c) USING(a,b) FULL JOIN prt2 p3(b,a,c) ON (p2.a = p3.a AND p2.b = p3.b)
+  WHERE p1.a BETWEEN 490 AND 510
+  GROUP BY 1, 2 ORDER BY 1, 2;
+
+RESET enable_partitionwise_aggregate;
+
 -- Cases with non-nullable expressions in subquery results;
 -- make sure these go to null as expected
 EXPLAIN (COSTS OFF)
@@ -285,6 +307,7 @@ EXPLAIN (COSTS OFF)
 SELECT avg(t1.a), avg(t2.b), avg(t3.a + t3.b), t1.c, t2.c, t3.c FROM pht1 t1, pht2 t2, pht1_e t3 WHERE t1.b = t2.b AND t1.c = t2.c AND ltrim(t3.c, 'A') = t1.c GROUP BY t1.c, t2.c, t3.c ORDER BY t1.c, t2.c, t3.c;
 SELECT avg(t1.a), avg(t2.b), avg(t3.a + t3.b), t1.c, t2.c, t3.c FROM pht1 t1, pht2 t2, pht1_e t3 WHERE t1.b = t2.b AND t1.c = t2.c AND ltrim(t3.c, 'A') = t1.c GROUP BY t1.c, t2.c, t3.c ORDER BY t1.c, t2.c, t3.c;
 
+
 -- test default partition behavior for range
 ALTER TABLE prt1 DETACH PARTITION prt1_p3;
 ALTER TABLE prt1 ATTACH PARTITION prt1_p3 DEFAULT;
-- 
1.8.3.1

