From d7ab6346413998f6b37351cd9df224c678503144 Mon Sep 17 00:00:00 2001
From: Matheus Alcantara <mths.dev@pm.me>
Date: Mon, 20 Jan 2025 15:25:51 -0300
Subject: [PATCH v2 1/2] dblink: refactor get connection routines

Refactor dblink_get_conn and dblink_connect to move the logic of
actually opening the connection to the new connect_pg_server function
which them can be re-used on both functions.

This is a pre-work for a next commit that will add support for scram
pass-through authentication to dblink which will be able to implement
most of the logic into the connect_pg_server function which now already
have all necessary data information.
---
 contrib/dblink/dblink.c | 229 +++++++++++++++++++++-------------------
 1 file changed, 118 insertions(+), 111 deletions(-)

diff --git a/contrib/dblink/dblink.c b/contrib/dblink/dblink.c
index bed2dee3d72..e02c0f4d730 100644
--- a/contrib/dblink/dblink.c
+++ b/contrib/dblink/dblink.c
@@ -117,7 +117,7 @@ static bool dblink_connstr_has_pw(const char *connstr);
 static void dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr);
 static void dblink_res_error(PGconn *conn, const char *conname, PGresult *res,
 							 bool fail, const char *fmt,...) pg_attribute_printf(5, 6);
-static char *get_connect_string(const char *servername);
+static char *get_connect_string(ForeignServer *foreign_server, UserMapping *user_mapping);
 static char *escape_param_str(const char *str);
 static void validate_pkattnums(Relation rel,
 							   int2vector *pkattnums_arg, int32 pknumatts_arg,
@@ -127,13 +127,14 @@ static bool is_valid_dblink_option(const PQconninfoOption *options,
 static int	applyRemoteGucs(PGconn *conn);
 static void restoreLocalGucs(int nestlevel);
 
+static PGconn *connect_pg_server(char *conname_or_str, remoteConn *rconn);
+
 /* Global */
 static remoteConn *pconn = NULL;
 static HTAB *remoteConnHash = NULL;
 
 /* custom wait event values, retrieved from shared memory */
 static uint32 dblink_we_connect = 0;
-static uint32 dblink_we_get_conn = 0;
 static uint32 dblink_we_get_result = 0;
 
 /*
@@ -201,33 +202,7 @@ dblink_get_conn(char *conname_or_str,
 	}
 	else
 	{
-		const char *connstr;
-
-		connstr = get_connect_string(conname_or_str);
-		if (connstr == NULL)
-			connstr = conname_or_str;
-		dblink_connstr_check(connstr);
-
-		/* first time, allocate or get the custom wait event */
-		if (dblink_we_get_conn == 0)
-			dblink_we_get_conn = WaitEventExtensionNew("DblinkGetConnect");
-
-		/* OK to make connection */
-		conn = libpqsrv_connect(connstr, dblink_we_get_conn);
-
-		if (PQstatus(conn) == CONNECTION_BAD)
-		{
-			char	   *msg = pchomp(PQerrorMessage(conn));
-
-			libpqsrv_disconnect(conn);
-			ereport(ERROR,
-					(errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
-					 errmsg("could not establish connection"),
-					 errdetail_internal("%s", msg)));
-		}
-		dblink_security_check(conn, rconn, connstr);
-		if (PQclientEncoding(conn) != GetDatabaseEncoding())
-			PQsetClientEncoding(conn, GetDatabaseEncodingName());
+		conn = connect_pg_server(conname_or_str, rconn);
 		freeconn = true;
 		conname = NULL;
 	}
@@ -272,9 +247,7 @@ Datum
 dblink_connect(PG_FUNCTION_ARGS)
 {
 	char	   *conname_or_str = NULL;
-	char	   *connstr = NULL;
 	char	   *connname = NULL;
-	char	   *msg;
 	PGconn	   *conn = NULL;
 	remoteConn *rconn = NULL;
 
@@ -297,40 +270,7 @@ dblink_connect(PG_FUNCTION_ARGS)
 		rconn->newXactForCursor = false;
 	}
 
-	/* first check for valid foreign data server */
-	connstr = get_connect_string(conname_or_str);
-	if (connstr == NULL)
-		connstr = conname_or_str;
-
-	/* check password in connection string if not superuser */
-	dblink_connstr_check(connstr);
-
-	/* first time, allocate or get the custom wait event */
-	if (dblink_we_connect == 0)
-		dblink_we_connect = WaitEventExtensionNew("DblinkConnect");
-
-	/* OK to make connection */
-	conn = libpqsrv_connect(connstr, dblink_we_connect);
-
-	if (PQstatus(conn) == CONNECTION_BAD)
-	{
-		msg = pchomp(PQerrorMessage(conn));
-		libpqsrv_disconnect(conn);
-		if (rconn)
-			pfree(rconn);
-
-		ereport(ERROR,
-				(errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
-				 errmsg("could not establish connection"),
-				 errdetail_internal("%s", msg)));
-	}
-
-	/* check password actually used if not superuser */
-	dblink_security_check(conn, rconn, connstr);
-
-	/* attempt to set client encoding to match server encoding, if needed */
-	if (PQclientEncoding(conn) != GetDatabaseEncoding())
-		PQsetClientEncoding(conn, GetDatabaseEncodingName());
+	conn = connect_pg_server(conname_or_str, rconn);
 
 	if (connname)
 	{
@@ -2784,15 +2724,17 @@ dblink_res_error(PGconn *conn, const char *conname, PGresult *res,
  * Obtain connection string for a foreign server
  */
 static char *
-get_connect_string(const char *servername)
+get_connect_string(ForeignServer *foreign_server, UserMapping *user_mapping)
 {
-	ForeignServer *foreign_server = NULL;
-	UserMapping *user_mapping;
 	ListCell   *cell;
 	StringInfoData buf;
 	ForeignDataWrapper *fdw;
 	AclResult	aclresult;
-	char	   *srvname;
+
+	/* first gather the server connstr options */
+	Oid			serverid = foreign_server->serverid;
+	Oid			fdwid = foreign_server->fdwid;
+	Oid			userid = GetUserId();
 
 	static const PQconninfoOption *options = NULL;
 
@@ -2815,57 +2757,42 @@ get_connect_string(const char *servername)
 					 errdetail("Could not get libpq's default connection options.")));
 	}
 
-	/* first gather the server connstr options */
-	srvname = pstrdup(servername);
-	truncate_identifier(srvname, strlen(srvname), false);
-	foreign_server = GetForeignServerByName(srvname, true);
-
-	if (foreign_server)
-	{
-		Oid			serverid = foreign_server->serverid;
-		Oid			fdwid = foreign_server->fdwid;
-		Oid			userid = GetUserId();
-
-		user_mapping = GetUserMapping(userid, serverid);
-		fdw = GetForeignDataWrapper(fdwid);
-
-		/* Check permissions, user must have usage on the server. */
-		aclresult = object_aclcheck(ForeignServerRelationId, serverid, userid, ACL_USAGE);
-		if (aclresult != ACLCHECK_OK)
-			aclcheck_error(aclresult, OBJECT_FOREIGN_SERVER, foreign_server->servername);
+	fdw = GetForeignDataWrapper(fdwid);
 
-		foreach(cell, fdw->options)
-		{
-			DefElem    *def = lfirst(cell);
+	/* Check permissions, user must have usage on the server. */
+	aclresult = object_aclcheck(ForeignServerRelationId, serverid, userid, ACL_USAGE);
+	if (aclresult != ACLCHECK_OK)
+		aclcheck_error(aclresult, OBJECT_FOREIGN_SERVER, foreign_server->servername);
 
-			if (is_valid_dblink_option(options, def->defname, ForeignDataWrapperRelationId))
-				appendStringInfo(&buf, "%s='%s' ", def->defname,
-								 escape_param_str(strVal(def->arg)));
-		}
+	foreach(cell, fdw->options)
+	{
+		DefElem    *def = lfirst(cell);
 
-		foreach(cell, foreign_server->options)
-		{
-			DefElem    *def = lfirst(cell);
+		if (is_valid_dblink_option(options, def->defname, ForeignDataWrapperRelationId))
+			appendStringInfo(&buf, "%s='%s' ", def->defname,
+							 escape_param_str(strVal(def->arg)));
+	}
 
-			if (is_valid_dblink_option(options, def->defname, ForeignServerRelationId))
-				appendStringInfo(&buf, "%s='%s' ", def->defname,
-								 escape_param_str(strVal(def->arg)));
-		}
+	foreach(cell, foreign_server->options)
+	{
+		DefElem    *def = lfirst(cell);
 
-		foreach(cell, user_mapping->options)
-		{
+		if (is_valid_dblink_option(options, def->defname, ForeignServerRelationId))
+			appendStringInfo(&buf, "%s='%s' ", def->defname,
+							 escape_param_str(strVal(def->arg)));
+	}
 
-			DefElem    *def = lfirst(cell);
+	foreach(cell, user_mapping->options)
+	{
 
-			if (is_valid_dblink_option(options, def->defname, UserMappingRelationId))
-				appendStringInfo(&buf, "%s='%s' ", def->defname,
-								 escape_param_str(strVal(def->arg)));
-		}
+		DefElem    *def = lfirst(cell);
 
-		return buf.data;
+		if (is_valid_dblink_option(options, def->defname, UserMappingRelationId))
+			appendStringInfo(&buf, "%s='%s' ", def->defname,
+							 escape_param_str(strVal(def->arg)));
 	}
-	else
-		return NULL;
+
+	return buf.data;
 }
 
 /*
@@ -3087,3 +3014,83 @@ restoreLocalGucs(int nestlevel)
 	if (nestlevel > 0)
 		AtEOXact_GUC(true, nestlevel);
 }
+
+/*
+ * Connect to remote server. If connstr_or_srvname maps to a foreign server,
+ * the associated properties and user mapping properties is also used to open
+ * the connection. Otherwise a connection will be open using the raw
+ * connstr_or_srvname value.
+ */
+static PGconn *
+connect_pg_server(char *connstr_or_srvname, remoteConn *rconn)
+{
+	PGconn	   *conn;
+	ForeignServer *foreign_server = NULL;
+	const char *connstr;
+	char	   *srvname;
+	Oid			serverid;
+	UserMapping *user_mapping;
+	Oid			userid = GetUserId();
+
+	static const PQconninfoOption *options = NULL;
+
+	/*
+	 * Get list of valid libpq options.
+	 *
+	 * To avoid unnecessary work, we get the list once and use it throughout
+	 * the lifetime of this backend process.  We don't need to care about
+	 * memory context issues, because PQconndefaults allocates with malloc.
+	 */
+	if (!options)
+	{
+		options = PQconndefaults();
+		if (!options)			/* assume reason for failure is OOM */
+			ereport(ERROR,
+					(errcode(ERRCODE_FDW_OUT_OF_MEMORY),
+					 errmsg("out of memory"),
+					 errdetail("Could not get libpq's default connection options.")));
+	}
+
+	/* first gather the server connstr options */
+	srvname = pstrdup(connstr_or_srvname);
+	truncate_identifier(srvname, strlen(srvname), false);
+	foreign_server = GetForeignServerByName(srvname, true);
+
+	if (foreign_server)
+	{
+		serverid = foreign_server->serverid;
+		user_mapping = GetUserMapping(userid, serverid);
+
+		connstr = get_connect_string(foreign_server, user_mapping);
+	}
+	else
+		connstr = connstr_or_srvname;
+
+	dblink_connstr_check(connstr);
+
+	/* first time, allocate or get the custom wait event */
+	if (dblink_we_connect == 0)
+		dblink_we_connect = WaitEventExtensionNew("DblinkConnectPgServer");
+
+	/* OK to make connection */
+	conn = libpqsrv_connect(connstr, dblink_we_connect);
+
+	if (!conn || PQstatus(conn) != CONNECTION_OK)
+	{
+		char	   *msg = pchomp(PQerrorMessage(conn));
+
+		libpqsrv_disconnect(conn);
+		ereport(ERROR,
+				(errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
+				 errmsg("could not establish connection"),
+				 errdetail_internal("%s", msg)));
+	}
+
+	dblink_security_check(conn, rconn, connstr);
+
+	/* attempt to set client encoding to match server encoding, if needed */
+	if (PQclientEncoding(conn) != GetDatabaseEncoding())
+		PQsetClientEncoding(conn, GetDatabaseEncodingName());
+
+	return conn;
+}
-- 
2.39.5 (Apple Git-154)

