From 8426609b716cfc7cd8842f8d22af4e4f5e2935b2 Mon Sep 17 00:00:00 2001
From: Richard Guo <guofenglinux@gmail.com>
Date: Fri, 9 Dec 2022 12:07:15 +0800
Subject: [PATCH v1] Check lateral references within PHVs for memoize cache
 keys

---
 src/backend/optimizer/plan/initsplan.c | 51 ++++++++++++++++++++++++++
 src/test/regress/expected/memoize.out  | 31 ++++++++++++++++
 src/test/regress/sql/memoize.sql       | 11 ++++++
 3 files changed, 93 insertions(+)

diff --git a/src/backend/optimizer/plan/initsplan.c b/src/backend/optimizer/plan/initsplan.c
index fd8cbb1dc7..68ff0975d1 100644
--- a/src/backend/optimizer/plan/initsplan.c
+++ b/src/backend/optimizer/plan/initsplan.c
@@ -523,12 +523,49 @@ create_lateral_join_info(PlannerInfo *root)
 		PlaceHolderInfo *phinfo = (PlaceHolderInfo *) lfirst(lc);
 		Relids		eval_at = phinfo->ph_eval_at;
 		int			varno;
+		List	   *vars;
+		List	   *ph_lateral_vars;
+		ListCell   *cell;
 
 		if (phinfo->ph_lateral == NULL)
 			continue;			/* PHV is uninteresting if no lateral refs */
 
 		found_laterals = true;
 
+		/* Fetch Vars and PHVs of lateral references within PlaceHolderVars */
+		vars = pull_vars_of_level((Node *) phinfo->ph_var->phexpr, 0);
+
+		ph_lateral_vars = NIL;
+		foreach(cell, vars)
+		{
+			Node	   *node = (Node *) lfirst(cell);
+
+			node = copyObject(node);
+			if (IsA(node, Var))
+			{
+				Var		   *var = (Var *) node;
+
+				Assert(var->varlevelsup == 0);
+
+				if (bms_is_member(var->varno, phinfo->ph_lateral))
+					ph_lateral_vars = lappend(ph_lateral_vars, node);
+			}
+			else if (IsA(node, PlaceHolderVar))
+			{
+				PlaceHolderVar *phv = (PlaceHolderVar *) node;
+
+				Assert(phv->phlevelsup == 0);
+
+				if (bms_is_subset(find_placeholder_info(root, phv)->ph_eval_at,
+								  phinfo->ph_lateral))
+					ph_lateral_vars = lappend(ph_lateral_vars, node);
+			}
+			else
+				Assert(false);
+		}
+
+		list_free(vars);
+
 		if (bms_get_singleton_member(eval_at, &varno))
 		{
 			/* Evaluation site is a baserel */
@@ -540,6 +577,13 @@ create_lateral_join_info(PlannerInfo *root)
 			brel->lateral_relids =
 				bms_add_members(brel->lateral_relids,
 								phinfo->ph_lateral);
+
+			/*
+			 * Remember the lateral references. We need this info for searching
+			 * memoize cache keys.
+			 */
+			brel->lateral_vars =
+				list_concat(brel->lateral_vars, ph_lateral_vars);
 		}
 		else
 		{
@@ -551,6 +595,13 @@ create_lateral_join_info(PlannerInfo *root)
 
 				brel->lateral_relids = bms_add_members(brel->lateral_relids,
 													   phinfo->ph_lateral);
+
+				/*
+				 * Remember the lateral references. We need this info for
+				 * searching memoize cache keys.
+				 */
+				brel->lateral_vars = list_concat(brel->lateral_vars,
+												 ph_lateral_vars);
 			}
 		}
 	}
diff --git a/src/test/regress/expected/memoize.out b/src/test/regress/expected/memoize.out
index de43afa76e..cf238661d2 100644
--- a/src/test/regress/expected/memoize.out
+++ b/src/test/regress/expected/memoize.out
@@ -90,6 +90,37 @@ WHERE t1.unique1 < 1000;
   1000 | 9.5000000000000000
 (1 row)
 
+-- Try with LATERAL references within PlaceHolderVars
+SELECT explain_memoize('
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;', false);
+                                      explain_memoize                                      
+-------------------------------------------------------------------------------------------
+ Aggregate (actual rows=1 loops=N)
+   ->  Nested Loop (actual rows=1000 loops=N)
+         ->  Seq Scan on tenk1 t1 (actual rows=1000 loops=N)
+               Filter: (unique1 < 1000)
+               Rows Removed by Filter: 9000
+         ->  Memoize (actual rows=1 loops=N)
+               Cache Key: t1.twenty
+               Cache Mode: binary
+               Hits: 980  Misses: 20  Evictions: Zero  Overflows: 0  Memory Usage: NkB
+               ->  Index Only Scan using tenk1_unique1 on tenk1 t2 (actual rows=1 loops=N)
+                     Filter: (t1.twenty = unique1)
+                     Rows Removed by Filter: 9999
+                     Heap Fetches: N
+(13 rows)
+
+-- And check we get the expected results.
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;
+ count |        avg         
+-------+--------------------
+  1000 | 9.5000000000000000
+(1 row)
+
 -- Reduce work_mem and hash_mem_multiplier so that we see some cache evictions
 SET work_mem TO '64kB';
 SET hash_mem_multiplier TO 1.0;
diff --git a/src/test/regress/sql/memoize.sql b/src/test/regress/sql/memoize.sql
index 17c5b4bfab..512c6dfeba 100644
--- a/src/test/regress/sql/memoize.sql
+++ b/src/test/regress/sql/memoize.sql
@@ -55,6 +55,17 @@ SELECT COUNT(*),AVG(t2.unique1) FROM tenk1 t1,
 LATERAL (SELECT t2.unique1 FROM tenk1 t2 WHERE t1.twenty = t2.unique1) t2
 WHERE t1.unique1 < 1000;
 
+-- Try with LATERAL references within PlaceHolderVars
+SELECT explain_memoize('
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;', false);
+
+-- And check we get the expected results.
+SELECT COUNT(*),AVG(t1.twenty) FROM tenk1 t1 LEFT JOIN
+LATERAL (SELECT t1.twenty as c1, t2.unique1 as c2 FROM tenk1 t2) s on true
+WHERE s.c1 = s.c2 and t1.unique1 < 1000;
+
 -- Reduce work_mem and hash_mem_multiplier so that we see some cache evictions
 SET work_mem TO '64kB';
 SET hash_mem_multiplier TO 1.0;
-- 
2.31.0

