Hi,

As far as I can see, initial partition pruning doesn't work when a ScalarArrayOpExpr has a parameter as the right-hand side of the expression, like the following:

partkey = ANY($1)

According to colleagues, it is quite typical to use stored procedures that take an array of IDs as a parameter and use it in the WHERE clause of a SELECT query.

So, here I propose a patch that extends the pruning machinery. It is nothing innovative or complicated, but I'm not sure it is fully functional yet: it may need some discussion, review, and polishing.

I intend to add it to the next commitfest if this feature makes sense.

--
regards, Andrei Lepikhov
From f6daa91cd6630e85a0ba37bcba41ae06c4a1fa34 Mon Sep 17 00:00:00 2001
From: "Andrei V. Lepikhov" <lepi...@gmail.com>
Date: Fri, 14 Mar 2025 12:51:42 +0100
Subject: [PATCH] Enhance partition pruning for an array parameter.

It is designed to prune partitions in the case when an incoming clause looks like
the following: 'partkey = ANY($1)'.

This seems to be a common case, e.g. when the array is passed as a parameter to a function.

Although the code is covered by tests, it should be carefully reviewed and
tested.
---
 src/backend/partitioning/partprune.c  |  71 +++++++++-
 src/test/regress/expected/inherit.out | 180 ++++++++++++++++++++++++++
 src/test/regress/sql/inherit.sql      |  64 +++++++++
 3 files changed, 311 insertions(+), 4 deletions(-)

diff --git a/src/backend/partitioning/partprune.c b/src/backend/partitioning/partprune.c
index 48a35f763e..622f0d230a 100644
--- a/src/backend/partitioning/partprune.c
+++ b/src/backend/partitioning/partprune.c
@@ -2179,6 +2179,9 @@ match_clause_to_partition_key(GeneratePruningStepsContext *context,
 		List	   *elem_exprs,
 				   *elem_clauses;
 		ListCell   *lc1;
+		int			strategy;
+		Oid			lefttype,
+					righttype;
 
 		if (IsA(leftop, RelabelType))
 			leftop = ((RelabelType *) leftop)->arg;
@@ -2206,10 +2209,6 @@ match_clause_to_partition_key(GeneratePruningStepsContext *context,
 			negator = get_negator(saop_op);
 			if (OidIsValid(negator) && op_in_opfamily(negator, partopfamily))
 			{
-				int			strategy;
-				Oid			lefttype,
-							righttype;
-
 				get_op_opfamily_properties(negator, partopfamily,
 										   false, &strategy,
 										   &lefttype, &righttype);
@@ -2219,6 +2218,12 @@ match_clause_to_partition_key(GeneratePruningStepsContext *context,
 			else
 				return PARTCLAUSE_NOMATCH;	/* no useful negator */
 		}
+		else
+		{
+			get_op_opfamily_properties(saop_op, partopfamily, false,
+									   &strategy, &lefttype,
+									   &righttype);
+		}
 
 		/*
 		 * Only allow strict operators.  This will guarantee nulls are
@@ -2365,6 +2370,64 @@ match_clause_to_partition_key(GeneratePruningStepsContext *context,
 			 */
 			elem_exprs = arrexpr->elements;
 		}
+		else if (IsA(rightop, Param))
+		{
+			Oid				cmpfn;
+			PartClauseInfo *partclause;
+
+			if (righttype == part_scheme->partopcintype[partkeyidx])
+				cmpfn = part_scheme->partsupfunc[partkeyidx].fn_oid;
+			else
+			{
+				switch (part_scheme->strategy)
+				{
+					/*
+					 * For range and list partitioning, we need the ordering
+					 * procedure with lefttype being the partition key's type,
+					 * and righttype the clause's operator's right type.
+					 */
+				case PARTITION_STRATEGY_LIST:
+				case PARTITION_STRATEGY_RANGE:
+					cmpfn =
+						get_opfamily_proc(part_scheme->partopfamily[partkeyidx],
+										  part_scheme->partopcintype[partkeyidx],
+										  righttype, BTORDER_PROC);
+					break;
+
+					/*
+					 * For hash partitioning, we need the hashing procedure
+					 * for the clause's type.
+					 */
+				case PARTITION_STRATEGY_HASH:
+					cmpfn =
+						get_opfamily_proc(part_scheme->partopfamily[partkeyidx],
+										  righttype, righttype,
+										  HASHEXTENDED_PROC);
+					break;
+
+				default:
+					elog(ERROR, "invalid partition strategy: %c",
+						 part_scheme->strategy);
+					cmpfn = InvalidOid; /* keep compiler quiet */
+					break;
+				}
+
+				if (!OidIsValid(cmpfn))
+					return PARTCLAUSE_NOMATCH;
+			}
+
+			partclause = (PartClauseInfo *) palloc(sizeof(PartClauseInfo));
+			partclause->keyno = partkeyidx;
+			partclause->opno = saop_op;
+			partclause->op_is_ne = false;
+			partclause->op_strategy = strategy;
+			partclause->expr = rightop;
+			partclause->cmpfn = cmpfn;
+
+			*pc = partclause;
+
+			return PARTCLAUSE_MATCH_CLAUSE;
+		}
 		else
 		{
 			/* Give up on any other clause types. */
diff --git a/src/test/regress/expected/inherit.out b/src/test/regress/expected/inherit.out
index e671975a28..d5f3137b93 100644
--- a/src/test/regress/expected/inherit.out
+++ b/src/test/regress/expected/inherit.out
@@ -3823,3 +3823,183 @@ select * from tuplesest_tab join
 
 drop table tuplesest_parted;
 drop table tuplesest_tab;
+--
+-- Test the cases for partition pruning by an expression like:
+-- partkey = ANY($1)
+--
+CREATE TABLE array_prune (id int)
+PARTITION BY HASH(id);
+CREATE TABLE array_prune_t0
+ PARTITION OF array_prune FOR VALUES WITH (modulus 2, remainder 0);
+CREATE TABLE array_prune_t1
+ PARTITION OF array_prune FOR VALUES WITH (modulus 2, remainder 1);
+CREATE FUNCTION array_prune_fn(oper text, arr text) RETURNS setof text
+LANGUAGE plpgsql AS $$
+DECLARE
+    line text;
+    query text;
+BEGIN
+  query := format('EXPLAIN (COSTS OFF) SELECT * FROM array_prune WHERE id %s (%s)', $1, $2);
+  FOR line IN EXECUTE query
+  LOOP
+    RETURN NEXT line;
+  END LOOP;
+END; $$;
+SELECT array_prune_fn('= ANY', 'ARRAY[1]'); -- prune one partition
+             array_prune_fn              
+-----------------------------------------
+ Seq Scan on array_prune_t0 array_prune
+   Filter: (id = ANY ('{1}'::integer[]))
+(2 rows)
+
+SELECT array_prune_fn('= ANY', 'ARRAY[1,2]');  -- prune one partition
+              array_prune_fn               
+-------------------------------------------
+ Seq Scan on array_prune_t0 array_prune
+   Filter: (id = ANY ('{1,2}'::integer[]))
+(2 rows)
+
+SELECT array_prune_fn('= ANY', 'ARRAY[1,2,3]');  -- no pruning
+                  array_prune_fn                   
+---------------------------------------------------
+ Append
+   ->  Seq Scan on array_prune_t0 array_prune_1
+         Filter: (id = ANY ('{1,2,3}'::integer[]))
+   ->  Seq Scan on array_prune_t1 array_prune_2
+         Filter: (id = ANY ('{1,2,3}'::integer[]))
+(5 rows)
+
+SELECT array_prune_fn('= ANY', 'ARRAY[1, NULL]'); -- prune
+                array_prune_fn                
+----------------------------------------------
+ Seq Scan on array_prune_t0 array_prune
+   Filter: (id = ANY ('{1,NULL}'::integer[]))
+(2 rows)
+
+SELECT array_prune_fn('= ANY', 'ARRAY[3, NULL]'); -- prune
+                array_prune_fn                
+----------------------------------------------
+ Seq Scan on array_prune_t1 array_prune
+   Filter: (id = ANY ('{3,NULL}'::integer[]))
+(2 rows)
+
+SELECT array_prune_fn('= ANY', 'ARRAY[NULL, NULL]'); -- error
+ERROR:  operator does not exist: integer = text
+LINE 1: ...IN (COSTS OFF) SELECT * FROM array_prune WHERE id = ANY (ARR...
+                                                             ^
+HINT:  No operator matches the given name and argument types. You might need to add explicit type casts.
+QUERY:  EXPLAIN (COSTS OFF) SELECT * FROM array_prune WHERE id = ANY (ARRAY[NULL, NULL])
+CONTEXT:  PL/pgSQL function array_prune_fn(text,text) line 7 at FOR over EXECUTE statement
+-- Check case of explicit cast
+SELECT array_prune_fn('= ANY', 'ARRAY[1,2]::numeric[]');
+                       array_prune_fn                       
+------------------------------------------------------------
+ Append
+   ->  Seq Scan on array_prune_t0 array_prune_1
+         Filter: ((id)::numeric = ANY ('{1,2}'::numeric[]))
+   ->  Seq Scan on array_prune_t1 array_prune_2
+         Filter: ((id)::numeric = ANY ('{1,2}'::numeric[]))
+(5 rows)
+
+SELECT array_prune_fn('= ANY', 'ARRAY[1::bigint,2::int]'); -- conversion to bigint
+              array_prune_fn              
+------------------------------------------
+ Seq Scan on array_prune_t0 array_prune
+   Filter: (id = ANY ('{1,2}'::bigint[]))
+(2 rows)
+
+SELECT array_prune_fn('= ANY', 'ARRAY[1::bigint,2::numeric]'); -- conversion to numeric
+                       array_prune_fn                       
+------------------------------------------------------------
+ Append
+   ->  Seq Scan on array_prune_t0 array_prune_1
+         Filter: ((id)::numeric = ANY ('{1,2}'::numeric[]))
+   ->  Seq Scan on array_prune_t1 array_prune_2
+         Filter: ((id)::numeric = ANY ('{1,2}'::numeric[]))
+(5 rows)
+
+SELECT array_prune_fn('= ANY', 'ARRAY[1::bigint,2::text]'); -- Error. XXX: slightly different error in comparison with the static case
+ERROR:  ARRAY types bigint and text cannot be matched
+LINE 1: ...* FROM array_prune WHERE id = ANY (ARRAY[1::bigint,2::text])
+                                                              ^
+QUERY:  EXPLAIN (COSTS OFF) SELECT * FROM array_prune WHERE id = ANY (ARRAY[1::bigint,2::text])
+CONTEXT:  PL/pgSQL function array_prune_fn(text,text) line 7 at FOR over EXECUTE statement
+SELECT array_prune_fn('<> ANY', 'ARRAY[1]'); -- no pruning
+                 array_prune_fn                 
+------------------------------------------------
+ Append
+   ->  Seq Scan on array_prune_t0 array_prune_1
+         Filter: (id <> ANY ('{1}'::integer[]))
+   ->  Seq Scan on array_prune_t1 array_prune_2
+         Filter: (id <> ANY ('{1}'::integer[]))
+(5 rows)
+
+DROP TABLE IF EXISTS array_prune CASCADE;
+CREATE TABLE array_prune (id int)
+PARTITION BY RANGE(id);
+CREATE TABLE array_prune_t0
+ PARTITION OF array_prune FOR VALUES FROM (1) TO (10);
+CREATE TABLE array_prune_t1
+ PARTITION OF array_prune FOR VALUES FROM (10) TO (20);
+SELECT array_prune_fn('= ANY', 'ARRAY[10]'); -- prune
+              array_prune_fn              
+------------------------------------------
+ Seq Scan on array_prune_t1 array_prune
+   Filter: (id = ANY ('{10}'::integer[]))
+(2 rows)
+
+SELECT array_prune_fn('>= ANY', 'ARRAY[10]'); -- prune
+              array_prune_fn               
+-------------------------------------------
+ Seq Scan on array_prune_t1 array_prune
+   Filter: (id >= ANY ('{10}'::integer[]))
+(2 rows)
+
+SELECT array_prune_fn('>= ANY', 'ARRAY[9, 10]'); -- do not prune
+                  array_prune_fn                   
+---------------------------------------------------
+ Append
+   ->  Seq Scan on array_prune_t0 array_prune_1
+         Filter: (id >= ANY ('{9,10}'::integer[]))
+   ->  Seq Scan on array_prune_t1 array_prune_2
+         Filter: (id >= ANY ('{9,10}'::integer[]))
+(5 rows)
+
+DROP TABLE IF EXISTS array_prune CASCADE;
+CREATE TABLE array_prune (id int)
+PARTITION BY LIST(id);
+CREATE TABLE array_prune_t0
+ PARTITION OF array_prune FOR VALUES IN ('1');
+CREATE TABLE array_prune_t1
+ PARTITION OF array_prune FOR VALUES IN ('2');
+SELECT array_prune_fn('= ANY', 'ARRAY[1,1]'); -- prune second
+              array_prune_fn               
+-------------------------------------------
+ Seq Scan on array_prune_t0 array_prune
+   Filter: (id = ANY ('{1,1}'::integer[]))
+(2 rows)
+
+SELECT array_prune_fn('>= ANY', 'ARRAY[1,2]'); -- do not prune
+                  array_prune_fn                  
+--------------------------------------------------
+ Append
+   ->  Seq Scan on array_prune_t0 array_prune_1
+         Filter: (id >= ANY ('{1,2}'::integer[]))
+   ->  Seq Scan on array_prune_t1 array_prune_2
+         Filter: (id >= ANY ('{1,2}'::integer[]))
+(5 rows)
+
+SELECT array_prune_fn('<> ANY', 'ARRAY[1]'); -- prune second
+              array_prune_fn              
+------------------------------------------
+ Seq Scan on array_prune_t1 array_prune
+   Filter: (id <> ANY ('{1}'::integer[]))
+(2 rows)
+
+SELECT array_prune_fn('<> ALL', 'ARRAY[1,2]'); -- prune both
+      array_prune_fn      
+--------------------------
+ Result
+   One-Time Filter: false
+(2 rows)
+
diff --git a/src/test/regress/sql/inherit.sql b/src/test/regress/sql/inherit.sql
index 4e73c70495..d91a8024a5 100644
--- a/src/test/regress/sql/inherit.sql
+++ b/src/test/regress/sql/inherit.sql
@@ -1581,3 +1581,67 @@ select * from tuplesest_tab join
 
 drop table tuplesest_parted;
 drop table tuplesest_tab;
+
+--
+-- Test the cases for partition pruning by an expression like:
+-- partkey = ANY($1)
+--
+
+CREATE TABLE array_prune (id int)
+PARTITION BY HASH(id);
+
+CREATE TABLE array_prune_t0
+ PARTITION OF array_prune FOR VALUES WITH (modulus 2, remainder 0);
+CREATE TABLE array_prune_t1
+ PARTITION OF array_prune FOR VALUES WITH (modulus 2, remainder 1);
+
+CREATE FUNCTION array_prune_fn(oper text, arr text) RETURNS setof text
+LANGUAGE plpgsql AS $$
+DECLARE
+    line text;
+    query text;
+BEGIN
+  query := format('EXPLAIN (COSTS OFF) SELECT * FROM array_prune WHERE id %s (%s)', $1, $2);
+  FOR line IN EXECUTE query
+  LOOP
+    RETURN NEXT line;
+  END LOOP;
+END; $$;
+
+SELECT array_prune_fn('= ANY', 'ARRAY[1]'); -- prune one partition
+SELECT array_prune_fn('= ANY', 'ARRAY[1,2]');  -- prune one partition
+SELECT array_prune_fn('= ANY', 'ARRAY[1,2,3]');  -- no pruning
+SELECT array_prune_fn('= ANY', 'ARRAY[1, NULL]'); -- prune
+SELECT array_prune_fn('= ANY', 'ARRAY[3, NULL]'); -- prune
+SELECT array_prune_fn('= ANY', 'ARRAY[NULL, NULL]'); -- error
+-- Check case of explicit cast
+SELECT array_prune_fn('= ANY', 'ARRAY[1,2]::numeric[]');
+SELECT array_prune_fn('= ANY', 'ARRAY[1::bigint,2::int]'); -- conversion to bigint
+SELECT array_prune_fn('= ANY', 'ARRAY[1::bigint,2::numeric]'); -- conversion to numeric
+SELECT array_prune_fn('= ANY', 'ARRAY[1::bigint,2::text]'); -- Error. XXX: slightly different error in comparison with the static case
+
+SELECT array_prune_fn('<> ANY', 'ARRAY[1]'); -- no pruning
+
+DROP TABLE IF EXISTS array_prune CASCADE;
+CREATE TABLE array_prune (id int)
+PARTITION BY RANGE(id);
+CREATE TABLE array_prune_t0
+ PARTITION OF array_prune FOR VALUES FROM (1) TO (10);
+CREATE TABLE array_prune_t1
+ PARTITION OF array_prune FOR VALUES FROM (10) TO (20);
+
+SELECT array_prune_fn('= ANY', 'ARRAY[10]'); -- prune
+SELECT array_prune_fn('>= ANY', 'ARRAY[10]'); -- prune
+SELECT array_prune_fn('>= ANY', 'ARRAY[9, 10]'); -- do not prune
+
+DROP TABLE IF EXISTS array_prune CASCADE;
+CREATE TABLE array_prune (id int)
+PARTITION BY LIST(id);
+CREATE TABLE array_prune_t0
+ PARTITION OF array_prune FOR VALUES IN ('1');
+CREATE TABLE array_prune_t1
+ PARTITION OF array_prune FOR VALUES IN ('2');
+SELECT array_prune_fn('= ANY', 'ARRAY[1,1]'); -- prune second
+SELECT array_prune_fn('>= ANY', 'ARRAY[1,2]'); -- do not prune
+SELECT array_prune_fn('<> ANY', 'ARRAY[1]'); -- prune second
+SELECT array_prune_fn('<> ALL', 'ARRAY[1,2]'); -- prune both
-- 
2.39.5

Reply via email to