From 40b51c8b58a5b6eef5b7c21bdea9c6d174a24054 Mon Sep 17 00:00:00 2001
From: Andy Fan <zhihui.fan1213@gmail.com>
Date: Sun, 9 Oct 2022 17:47:23 +0800
Subject: [PATCH v1] Pulling up direct-correlated ANY_SUBLINK

The current convert_ANY_sublink_to_join implementation gives up on any
ANY sublink whose subselect contains Vars with varlevelsup equal to 1:

  if (contain_vars_of_level((Node *) subselect, 1))
	return NULL;

This patch transforms such IN sublinks (ANY_SUBLINK with the "=" operator)
into EXISTS_SUBLINK before sublink pull-up is attempted.
---
 .../postgres_fdw/expected/postgres_fdw.out    |  24 +-
 src/backend/optimizer/prep/prepjointree.c     | 219 ++++++++++++++++++
 src/test/regress/expected/join.out            |  35 +--
 src/test/regress/expected/subselect.out       | 121 ++++++++++
 src/test/regress/sql/subselect.sql            |  61 +++++
 5 files changed, 431 insertions(+), 29 deletions(-)

diff --git a/contrib/postgres_fdw/expected/postgres_fdw.out b/contrib/postgres_fdw/expected/postgres_fdw.out
index cc9e39c4a5f..4da664e6899 100644
--- a/contrib/postgres_fdw/expected/postgres_fdw.out
+++ b/contrib/postgres_fdw/expected/postgres_fdw.out
@@ -11267,19 +11267,19 @@ CREATE FOREIGN TABLE foreign_tbl2 () INHERITS (foreign_tbl)
   SERVER loopback OPTIONS (table_name 'base_tbl');
 EXPLAIN (VERBOSE, COSTS OFF)
 SELECT a FROM base_tbl WHERE a IN (SELECT a FROM foreign_tbl);
-                                 QUERY PLAN                                  
------------------------------------------------------------------------------
- Seq Scan on public.base_tbl
+                                QUERY PLAN                                 
+---------------------------------------------------------------------------
+ Nested Loop Semi Join
    Output: base_tbl.a
-   Filter: (SubPlan 1)
-   SubPlan 1
-     ->  Result
-           Output: base_tbl.a
-           ->  Append
-                 ->  Async Foreign Scan on public.foreign_tbl foreign_tbl_1
-                       Remote SQL: SELECT NULL FROM public.base_tbl
-                 ->  Async Foreign Scan on public.foreign_tbl2 foreign_tbl_2
-                       Remote SQL: SELECT NULL FROM public.base_tbl
+   ->  Seq Scan on public.base_tbl
+         Output: base_tbl.a, base_tbl.b
+         Filter: (base_tbl.a IS NOT NULL)
+   ->  Materialize
+         ->  Append
+               ->  Async Foreign Scan on public.foreign_tbl foreign_tbl_1
+                     Remote SQL: SELECT NULL FROM public.base_tbl
+               ->  Async Foreign Scan on public.foreign_tbl2 foreign_tbl_2
+                     Remote SQL: SELECT NULL FROM public.base_tbl
 (11 rows)
 
 SELECT a FROM base_tbl WHERE a IN (SELECT a FROM foreign_tbl);
diff --git a/src/backend/optimizer/prep/prepjointree.c b/src/backend/optimizer/prep/prepjointree.c
index 41c7066d90a..0d3e84d4839 100644
--- a/src/backend/optimizer/prep/prepjointree.c
+++ b/src/backend/optimizer/prep/prepjointree.c
@@ -130,6 +130,7 @@ static void substitute_phv_relids(Node *node,
 static void fix_append_rel_relids(List *append_rel_list, int varno,
 								  Relids subrelids);
 static Node *find_jointree_node_for_rel(Node *jtnode, int relid);
+static void transform_IN_sublink_to_EXIST_recurse(Node *jtnode);
 
 
 /*
@@ -256,6 +257,222 @@ replace_empty_jointree(Query *parse)
 	parse->jointree->fromlist = list_make1(rtr);
 }
 
+
+/*
+ * is_IN_sublink
+ *
+ *	True iff sublink is an ANY sublink whose operator name is a bare "=".
+ */
+static bool
+is_IN_sublink(SubLink *sublink)
+{
+	String	   *opname;
+
+	if (sublink->subLinkType != ANY_SUBLINK)
+		return false;
+	if (list_length(sublink->operName) != 1)
+		return false;
+	opname = linitial_node(String, sublink->operName);
+	return strcmp(opname->sval, "=") == 0;
+}
+
+
+/*
+ * replace_param_sublink_node
+ *
+ *	Return src with its PARAM_SUBLINK Param replaced by target.  src is the
+ *	Param itself, or RelabelType/FuncExpr coercions stacked on it (a FuncExpr
+ *	is followed via its first argument), which are updated in place.
+ */
+static Node *
+replace_param_sublink_node(Node *src, Node *target)
+{
+	Assert(!IsA(src, Param) || ((Param *) src)->paramkind == PARAM_SUBLINK);
+	if (IsA(src, Param))
+		return target;
+
+	switch (nodeTag(src))
+	{
+		case T_RelabelType:
+			{
+				RelabelType *rtype = castNode(RelabelType, src);
+				/* Recurse, so a coercion nested below this one is kept. */
+				rtype->arg = (Expr *)
+					replace_param_sublink_node((Node *) rtype->arg, target);
+				break;
+			}
+		case T_FuncExpr:
+			{
+				FuncExpr   *fexpr = castNode(FuncExpr, src);
+				Assert(fexpr->args != NIL);
+				linitial(fexpr->args) =
+					replace_param_sublink_node(linitial(fexpr->args), target);
+				break;
+			}
+		default:
+			elog(ERROR, "unrecognized node type: %d", (int) nodeTag(src));
+	}
+
+	/* src was updated in place */
+	return src;
+}
+
+/*
+ * transform_IN_sublink_to_EXIST_qual_recurse
+ *
+ *	Examine a qual tree and rewrite each IN SubLink found at its top level,
+ *	or as an arm of its top-level ANDs, into an EXISTS SubLink when the
+ *	subselect contains level-1 Vars.  The IN comparison is moved into the
+ *	subselect's WHERE clause.  The tree is modified in place and returned.
+ */
+static Node *
+transform_IN_sublink_to_EXIST_qual_recurse(Node *node)
+{
+	if (node == NULL)
+		return NULL;
+
+	if (IsA(node, SubLink))
+	{
+		SubLink    *sublink = (SubLink *) node;
+		Query	   *subselect = (Query *) sublink->subselect;
+		FromExpr   *sub_fromexpr;
+		List	   *opexprs;
+		ListCell   *l1,
+				   *l2;
+
+		Assert(IsA(subselect, Query));
+
+		/*
+		 * Only an IN SubLink with level-1 Vars is worth transforming; the
+		 * other cases are handled by convert_ANY_sublink_to_join, possibly
+		 * better (for example as a hashed SubPlan).
+		 *
+		 * Moving the comparison into WHERE is only valid if filtering the
+		 * subselect's input rows is equivalent to filtering its output rows.
+		 * So give up when anything is computed after WHERE: aggregates,
+		 * window functions, set-returning functions, GROUP BY, HAVING,
+		 * DISTINCT ON, LIMIT/OFFSET or set operations.
+		 */
+		if (!is_IN_sublink(sublink) ||
+			!contain_vars_of_level((Node *) subselect, 1) ||
+			subselect->rtable == NIL ||
+			subselect->setOperations != NULL ||
+			subselect->hasAggs ||
+			subselect->hasWindowFuncs ||
+			subselect->hasTargetSRFs ||
+			subselect->groupClause != NIL ||
+			subselect->groupingSets != NIL ||
+			subselect->havingQual != NULL ||
+			subselect->hasDistinctOn ||
+			subselect->limitOffset != NULL ||
+			subselect->limitCount != NULL)
+		{
+			/* Leave it alone, but still look inside the subselect. */
+			transform_IN_sublink_to_EXIST_recurse((Node *) subselect->jointree);
+			return node;
+		}
+
+		/*
+		 * Build the pushed-down qual from sublink->testexpr.  It is detached
+		 * from the SubLink below, so updating it in place is OK.  Its Vars
+		 * end up one query level further down.
+		 */
+		IncrementVarSublevelsUp(sublink->testexpr, 1, 0);
+
+		if (is_andclause(sublink->testexpr))
+			opexprs = castNode(BoolExpr, sublink->testexpr)->args;
+		else
+			opexprs = list_make1(sublink->testexpr);
+
+		/* Point each comparison at the matching subselect output column. */
+		forboth(l1, opexprs, l2, subselect->targetList)
+		{
+			OpExpr	   *opexpr = lfirst_node(OpExpr, l1);
+			TargetEntry *tle = lfirst_node(TargetEntry, l2);
+
+			/* Copy it, so the targetlist and the qual share no subtree. */
+			lsecond(opexpr->args) =
+				replace_param_sublink_node(lsecond(opexpr->args),
+										   (Node *) copyObject(tle->expr));
+		}
+
+		sub_fromexpr = subselect->jointree;
+		if (sub_fromexpr->quals == NULL)
+			sub_fromexpr->quals = sublink->testexpr;
+		else
+			sub_fromexpr->quals = make_and_qual(sub_fromexpr->quals,
+												(Node *) sublink->testexpr);
+
+		/* The subselect may contain transformable SubLinks of its own. */
+		transform_IN_sublink_to_EXIST_recurse((Node *) sub_fromexpr);
+
+		/*
+		 * Turn the IN SubLink into an EXISTS SubLink for the parent query;
+		 * sublink->subselect has already been modified.
+		 */
+		sublink->subLinkType = EXISTS_SUBLINK;
+		sublink->operName = NIL;
+		sublink->testexpr = NULL;
+
+		return node;
+	}
+
+	if (is_andclause(node))
+	{
+		ListCell   *lc;
+
+		/* SubLinks are rewritten in place, so the AND can be kept as is. */
+		foreach(lc, ((BoolExpr *) node)->args)
+			lfirst(lc) = transform_IN_sublink_to_EXIST_qual_recurse(lfirst(lc));
+	}
+
+	/*
+	 * Nothing is done for other node types.  In particular NOT IN can't be
+	 * converted into NOT EXISTS; an IN SubLink inside its subselect can be
+	 * converted during the next subquery_planner.
+	 */
+	return node;
+}
+
+/*
+ * transform_IN_sublink_to_EXIST_recurse
+ *
+ *	Visit every FromExpr and JoinExpr of a jointree, children first, and run
+ *	transform_IN_sublink_to_EXIST_qual_recurse over the quals of each.
+ */
+static void
+transform_IN_sublink_to_EXIST_recurse(Node *jtnode)
+{
+	if (jtnode == NULL || IsA(jtnode, RangeTblRef))
+		return;					/* nothing to examine */
+
+	if (IsA(jtnode, FromExpr))
+	{
+		FromExpr   *from = (FromExpr *) jtnode;
+		ListCell   *lc;
+
+		/* children first, then this node's own quals */
+		foreach(lc, from->fromlist)
+			transform_IN_sublink_to_EXIST_recurse(lfirst(lc));
+		from->quals = transform_IN_sublink_to_EXIST_qual_recurse(from->quals);
+		return;
+	}
+
+	if (IsA(jtnode, JoinExpr))
+	{
+		JoinExpr   *join = (JoinExpr *) jtnode;
+
+		transform_IN_sublink_to_EXIST_recurse(join->larg);
+		transform_IN_sublink_to_EXIST_recurse(join->rarg);
+		join->quals = transform_IN_sublink_to_EXIST_qual_recurse(join->quals);
+		return;
+	}
+
+	elog(ERROR, "unrecognized node type: %d",
+		 (int) nodeTag(jtnode));
+}
+
+
 /*
  * pull_up_sublinks
  *		Attempt to pull up ANY and EXISTS SubLinks to be treated as
@@ -290,6 +507,8 @@ pull_up_sublinks(PlannerInfo *root)
 	Node	   *jtnode;
 	Relids		relids;
 
+	transform_IN_sublink_to_EXIST_recurse((Node *)root->parse->jointree);
+
 	/* Begin recursion through the jointree */
 	jtnode = pull_up_sublinks_jointree_recurse(root,
 											   (Node *) root->parse->jointree,
diff --git a/src/test/regress/expected/join.out b/src/test/regress/expected/join.out
index 08334761ae6..6d5ccb7de2f 100644
--- a/src/test/regress/expected/join.out
+++ b/src/test/regress/expected/join.out
@@ -5983,8 +5983,8 @@ lateral (select * from int8_tbl t1,
                                      where q2 = (select greatest(t1.q1,t2.q2))
                                        and (select v.id=0)) offset 0) ss2) ss
          where t1.q1 = ss.q2) ss0;
-                           QUERY PLAN                            
------------------------------------------------------------------
+                             QUERY PLAN                              
+---------------------------------------------------------------------
  Nested Loop
    Output: "*VALUES*".column1, t1.q1, t1.q2, ss2.q1, ss2.q2
    ->  Seq Scan on public.int8_tbl t1
@@ -5996,23 +5996,24 @@ lateral (select * from int8_tbl t1,
          ->  Subquery Scan on ss2
                Output: ss2.q1, ss2.q2
                Filter: (t1.q1 = ss2.q2)
-               ->  Seq Scan on public.int8_tbl t2
+               ->  Result
                      Output: t2.q1, t2.q2
-                     Filter: (SubPlan 3)
-                     SubPlan 3
+                     One-Time Filter: $3
+                     InitPlan 2 (returns $3)
                        ->  Result
-                             Output: t3.q2
-                             One-Time Filter: $4
-                             InitPlan 1 (returns $2)
-                               ->  Result
-                                     Output: GREATEST($0, t2.q2)
-                             InitPlan 2 (returns $4)
-                               ->  Result
-                                     Output: ($3 = 0)
-                             ->  Seq Scan on public.int8_tbl t3
-                                   Output: t3.q1, t3.q2
-                                   Filter: (t3.q2 = $2)
-(27 rows)
+                             Output: ($2 = 0)
+                     ->  Nested Loop Semi Join
+                           Output: t2.q1, t2.q2
+                           Join Filter: (t2.q1 = t3.q2)
+                           ->  Seq Scan on public.int8_tbl t2
+                                 Output: t2.q1, t2.q2
+                                 Filter: ((SubPlan 1) = t2.q1)
+                                 SubPlan 1
+                                   ->  Result
+                                         Output: GREATEST($0, t2.q2)
+                           ->  Seq Scan on public.int8_tbl t3
+                                 Output: t3.q1, t3.q2
+(28 rows)
 
 select * from (values (0), (1)) v(id),
 lateral (select * from int8_tbl t1,
diff --git a/src/test/regress/expected/subselect.out b/src/test/regress/expected/subselect.out
index 63d26d44fc3..5e0ce397233 100644
--- a/src/test/regress/expected/subselect.out
+++ b/src/test/regress/expected/subselect.out
@@ -1926,3 +1926,124 @@ select * from x for update;
    Output: subselect_tbl.f1, subselect_tbl.f2, subselect_tbl.f3
 (2 rows)
 
+-- Test transform the level-1 in-sublink to existing sublink.
+create temp table temp_t1 (a int, b int, c int) on commit delete rows;
+create temp table temp_t2 (a int, b int, c int) on commit delete rows;
+create temp table temp_t3 (a int, b int, c int) on commit delete rows;
+create temp table temp_t4 (a int, b int, c int, d int) on commit delete rows;
+begin;
+insert into temp_t1 values (1, 1, 1), (2, 2, null), (3, null, null);
+insert into temp_t2 values (1, 1, 1), (2, 2, null), (3, null, null);
+insert into temp_t3 values (1, 1, 1), (2, 2, null), (3, null, null);
+insert into temp_t4 values (1, 1, 1, 1), (2, 2, null, null), (3, null, null, null);
+analyze temp_t1;
+analyze temp_t2;
+analyze temp_t3;
+-- one-elem in subquery
+select * from temp_t1 t1 where a  in (select a from temp_t2 t2 where t2.b > t1.b);
+ a | b | c 
+---+---+---
+(0 rows)
+
+explain (costs off)
+select * from temp_t1 t1 where a  in (select a from temp_t2 t2 where t2.b > t1.b);
+             QUERY PLAN             
+------------------------------------
+ Hash Semi Join
+   Hash Cond: (t1.a = t2.a)
+   Join Filter: (t2.b > t1.b)
+   ->  Seq Scan on temp_t1 t1
+   ->  Hash
+         ->  Seq Scan on temp_t2 t2
+(6 rows)
+
+-- two-elem in subquery
+select * from temp_t1 t1 where (a, b)  in (select a, b from temp_t2 t2 where t2.c = t1.c);
+ a | b | c 
+---+---+---
+ 1 | 1 | 1
+(1 row)
+
+explain (costs off)
+select * from temp_t1 t1 where (a, b)  in (select a, b from temp_t2 t2 where t2.c = t1.c);
+                            QUERY PLAN                            
+------------------------------------------------------------------
+ Hash Semi Join
+   Hash Cond: ((t1.c = t2.c) AND (t1.a = t2.a) AND (t1.b = t2.b))
+   ->  Seq Scan on temp_t1 t1
+   ->  Hash
+         ->  Seq Scan on temp_t2 t2
+(5 rows)
+
+-- sublink in sublink
+select * from temp_t1 t1
+where (a, b) in (select a, b from temp_t2 t2
+                 where t2.c < t1.c
+		 and t2.c in (select c from temp_t3 t3 where t3.b = t2.b));
+ a | b | c 
+---+---+---
+(0 rows)
+
+explain (costs off)
+select * from temp_t1 t1
+where (a, b) in (select a, b from temp_t2 t2
+                 where t2.c < t1.c
+		 and t2.c in (select c from temp_t3 t3 where t3.b = t2.b));
+                             QUERY PLAN                             
+--------------------------------------------------------------------
+ Nested Loop Semi Join
+   Join Filter: ((t2.c < t1.c) AND (t1.b = t3.b) AND (t1.a = t2.a))
+   ->  Seq Scan on temp_t1 t1
+   ->  Materialize
+         ->  Hash Semi Join
+               Hash Cond: ((t2.b = t3.b) AND (t2.c = t3.c))
+               ->  Seq Scan on temp_t2 t2
+               ->  Hash
+                     ->  Seq Scan on temp_t3 t3
+(9 rows)
+
+-- sublink in not-in sublinks. not in will not be transformed but the in-clause
+-- in the subselect should be transformed.
+explain (costs off)
+select * from temp_t1 t1
+where (a, b) not in (select a, b from temp_t2 t2
+      	     	     where t2.c < t1.c
+                     and t1.c in (SELECT c from temp_t3 t3 where t3.b = t2.b ))
+and c > 3;
+                QUERY PLAN                 
+-------------------------------------------
+ Seq Scan on temp_t1 t1
+   Filter: ((c > 3) AND (NOT (SubPlan 1)))
+   SubPlan 1
+     ->  Nested Loop Semi Join
+           Join Filter: (t2.b = t3.b)
+           ->  Seq Scan on temp_t2 t2
+                 Filter: (c < t1.c)
+           ->  Seq Scan on temp_t3 t3
+                 Filter: (t1.c = c)
+(9 rows)
+
+-- The clause in the ON-clause should be transformed.
+explain (costs off)
+select * from temp_t1 t1, (temp_t2 t2 join temp_t4 t4
+                           on t2.a in (select a from temp_t3 t3 where t4.b = t3.b)) v
+where t1.a = v.d;
+                      QUERY PLAN                      
+------------------------------------------------------
+ Nested Loop
+   Join Filter: (t3.a = t2.a)
+   ->  Hash Join
+         Hash Cond: (t4.d = t1.a)
+         ->  Hash Join
+               Hash Cond: (t4.b = t3.b)
+               ->  Seq Scan on temp_t4 t4
+               ->  Hash
+                     ->  HashAggregate
+                           Group Key: t3.b, t3.a
+                           ->  Seq Scan on temp_t3 t3
+         ->  Hash
+               ->  Seq Scan on temp_t1 t1
+   ->  Seq Scan on temp_t2 t2
+(14 rows)
+
+commit;
diff --git a/src/test/regress/sql/subselect.sql b/src/test/regress/sql/subselect.sql
index 40276708c99..83ad18cf9b1 100644
--- a/src/test/regress/sql/subselect.sql
+++ b/src/test/regress/sql/subselect.sql
@@ -968,3 +968,64 @@ select * from (with x as (select 2 as y) select * from x) ss;
 explain (verbose, costs off)
 with x as (select * from subselect_tbl)
 select * from x for update;
+
+
+-- Test transform the level-1 in-sublink to existing sublink.
+create temp table temp_t1 (a int, b int, c int) on commit delete rows;
+create temp table temp_t2 (a int, b int, c int) on commit delete rows;
+create temp table temp_t3 (a int, b int, c int) on commit delete rows;
+create temp table temp_t4 (a int, b int, c int, d int) on commit delete rows;
+
+begin;
+insert into temp_t1 values (1, 1, 1), (2, 2, null), (3, null, null);
+insert into temp_t2 values (1, 1, 1), (2, 2, null), (3, null, null);
+insert into temp_t3 values (1, 1, 1), (2, 2, null), (3, null, null);
+insert into temp_t4 values (1, 1, 1, 1), (2, 2, null, null), (3, null, null, null);
+
+analyze temp_t1;
+analyze temp_t2;
+analyze temp_t3;
+
+-- one-elem in subquery
+select * from temp_t1 t1 where a  in (select a from temp_t2 t2 where t2.b > t1.b);
+explain (costs off)
+select * from temp_t1 t1 where a  in (select a from temp_t2 t2 where t2.b > t1.b);
+
+-- two-elem in subquery
+select * from temp_t1 t1 where (a, b)  in (select a, b from temp_t2 t2 where t2.c = t1.c);
+explain (costs off)
+select * from temp_t1 t1 where (a, b)  in (select a, b from temp_t2 t2 where t2.c = t1.c);
+
+-- sublink in sublink
+select * from temp_t1 t1
+where (a, b) in (select a, b from temp_t2 t2
+                 where t2.c < t1.c
+		 and t2.c in (select c from temp_t3 t3 where t3.b = t2.b));
+explain (costs off)
+select * from temp_t1 t1
+where (a, b) in (select a, b from temp_t2 t2
+                 where t2.c < t1.c
+		 and t2.c in (select c from temp_t3 t3 where t3.b = t2.b));
+
+
+-- sublink in not-in sublinks. not in will not be transformed but the in-clause
+-- in the subselect should be transformed.
+explain (costs off)
+select * from temp_t1 t1
+where (a, b) not in (select a, b from temp_t2 t2
+      	     	     where t2.c < t1.c
+                     and t1.c in (SELECT c from temp_t3 t3 where t3.b = t2.b ))
+and c > 3;
+
+
+-- The clause in the ON-clause should be transformed.
+explain (costs off)
+select * from temp_t1 t1, (temp_t2 t2 join temp_t4 t4
+                           on t2.a in (select a from temp_t3 t3 where t4.b = t3.b)) v
+where t1.a = v.d;
+
+commit;
+
+
+
+
-- 
2.21.0

