From 2f8c9faa7705ec13c2e140045faf393a8cc6928a Mon Sep 17 00:00:00 2001
From: Shlok Kyal <shlok.kyal.oss@gmail.com>
Date: Fri, 12 Jan 2024 15:55:30 +0530
Subject: [PATCH v20240117 2/2] Address some comments proposed on -hackers

This patch contains below changes.

* Add Timeout option and default timeout while waiting the recovery
* Restrict the target to be a standby node
* Reject when the --subscriber-conninfo specifies non-local server
---
 src/bin/pg_basebackup/pg_subscriber.c        | 153 +++++++++++++++----
 src/bin/pg_basebackup/t/040_pg_subscriber.pl |   8 +
 2 files changed, 133 insertions(+), 28 deletions(-)

diff --git a/src/bin/pg_basebackup/pg_subscriber.c b/src/bin/pg_basebackup/pg_subscriber.c
index e998c29f9e..2414c0f7ed 100644
--- a/src/bin/pg_basebackup/pg_subscriber.c
+++ b/src/bin/pg_basebackup/pg_subscriber.c
@@ -28,6 +28,7 @@
 #include "fe_utils/recovery_gen.h"
 #include "fe_utils/simple_list.h"
 #include "getopt_long.h"
+#include "libpq/pqcomm.h"
 #include "utils/pidfile.h"
 
 #define	PGS_OUTPUT_DIR	"pg_subscriber_output.d"
@@ -75,9 +76,13 @@ static void create_subscription(PGconn *conn, LogicalRepInfo *dbinfo);
 static void drop_subscription(PGconn *conn, LogicalRepInfo *dbinfo);
 static void set_replication_progress(PGconn *conn, LogicalRepInfo *dbinfo, const char *lsn);
 static void enable_subscription(PGconn *conn, LogicalRepInfo *dbinfo);
+static void start_standby_server(char *server_start_log);
 
 #define	USEC_PER_SEC	1000000
-#define	WAIT_INTERVAL	1		/* 1 second */
+#define DEFAULT_WAIT	60
+#define WAITS_PER_SEC	10		/* should divide USEC_PER_SEC evenly */
+
+static int	wait_seconds = DEFAULT_WAIT;
 
 /* Options */
 static const char *progname;
@@ -222,6 +227,27 @@ get_base_conninfo(char *conninfo, char *dbname, const char *noderole)
 			continue;
 		}
 
+		/*
+		 * If the dbname is NULL (this means the conninfo is for the
+		 * subscriber), we also check that the connection string does not
+		 * specify the non-local server.
+		 */
+		if (!dbname &&
+			conn_opt->val != NULL &&
+			(strcmp(conn_opt->keyword, "host") == 0 ||
+			 strcmp(conn_opt->keyword, "hostaddr") == 0))
+		{
+			const char *value = conn_opt->val;
+
+			if (strlen(value) > 0 &&
+			/* check for 'local' host values */
+				(strcmp(value, "localhost") != 0 &&
+				 strcmp(value, "127.0.0.1") != 0 &&
+				 strcmp(value, "::1") != 0 && !is_unixsock_path(value)))
+				pg_fatal("--subscriber-conninfo must not be non-local connection: %s",
+						 value);
+		}
+
 		if (conn_opt->val != NULL && conn_opt->val[0] != '\0')
 		{
 			if (i > 0)
@@ -830,6 +856,9 @@ wait_for_end_recovery(const char *conninfo)
 	PGconn	   *conn;
 	PGresult   *res;
 	int			status = POSTMASTER_STILL_STARTING;
+	int			cnt;
+	int			rc;
+	char	   *pg_ctl_cmd;
 
 	pg_log_info("waiting the postmaster to reach the consistent state");
 
@@ -837,7 +866,7 @@ wait_for_end_recovery(const char *conninfo)
 	if (conn == NULL)
 		exit(1);
 
-	for (;;)
+	for (cnt = 0; cnt < wait_seconds * WAITS_PER_SEC; cnt++)
 	{
 		bool		in_recovery;
 
@@ -870,11 +899,25 @@ wait_for_end_recovery(const char *conninfo)
 		}
 
 		/* Keep waiting. */
-		pg_usleep(WAIT_INTERVAL * USEC_PER_SEC);
+		pg_usleep(USEC_PER_SEC / WAITS_PER_SEC);
 	}
 
 	disconnect_database(conn);
 
+	/*
+	 * if timeout is reached exit the pg_subscriber and stop the standby node
+	 */
+	if (cnt >= wait_seconds * WAITS_PER_SEC)
+	{
+		pg_log_error("recovery timed out");
+
+		pg_ctl_cmd = psprintf("\"%s\" stop -D \"%s\" -s", pg_ctl_path, subscriber_dir);
+		rc = system(pg_ctl_cmd);
+		pg_ctl_status(pg_ctl_cmd, rc, 0);
+
+		exit(1);
+	}
+
 	if (status == POSTMASTER_STILL_STARTING)
 	{
 		pg_log_error("server did not end recovery");
@@ -1203,6 +1246,39 @@ enable_subscription(PGconn *conn, LogicalRepInfo *dbinfo)
 	destroyPQExpBuffer(str);
 }
 
+static void
+start_standby_server(char *server_start_log)
+{
+	char		timebuf[128];
+	struct timeval time;
+	time_t		tt;
+	int			len;
+	int			rc;
+	char	   *pg_ctl_cmd;
+
+	if (server_start_log[0] == '\0')
+	{
+		/* append timestamp with ISO 8601 format. */
+		gettimeofday(&time, NULL);
+		tt = (time_t) time.tv_sec;
+		strftime(timebuf, sizeof(timebuf), "%Y%m%dT%H%M%S", localtime(&tt));
+		snprintf(timebuf + strlen(timebuf), sizeof(timebuf) - strlen(timebuf),
+				 ".%03d", (int) (time.tv_usec / 1000));
+
+
+		len = snprintf(server_start_log, MAXPGPATH, "%s/%s/server_start_%s.log", subscriber_dir, PGS_OUTPUT_DIR, timebuf);
+		if (len >= MAXPGPATH)
+		{
+			pg_log_error("log file path is too long");
+			exit(1);
+		}
+	}
+
+	pg_ctl_cmd = psprintf("\"%s\" start -D \"%s\" -s -l \"%s\"", pg_ctl_path, subscriber_dir, server_start_log);
+	rc = system(pg_ctl_cmd);
+	pg_ctl_status(pg_ctl_cmd, rc, 1);
+}
+
 int
 main(int argc, char **argv)
 {
@@ -1214,6 +1290,7 @@ main(int argc, char **argv)
 		{"publisher-conninfo", required_argument, NULL, 'P'},
 		{"subscriber-conninfo", required_argument, NULL, 'S'},
 		{"database", required_argument, NULL, 'd'},
+		{"timeout", required_argument, NULL, 't'},
 		{"dry-run", no_argument, NULL, 'n'},
 		{"verbose", no_argument, NULL, 'v'},
 		{NULL, 0, NULL, 0}
@@ -1226,11 +1303,7 @@ main(int argc, char **argv)
 	char	   *pg_ctl_cmd;
 
 	char	   *base_dir;
-	char	   *server_start_log;
-
-	char		timebuf[128];
-	struct timeval time;
-	time_t		tt;
+	char		server_start_log[MAXPGPATH] = {0};
 	int			len;
 
 	char	   *pub_base_conninfo = NULL;
@@ -1250,6 +1323,8 @@ main(int argc, char **argv)
 
 	int			i;
 
+	PGresult   *res;
+
 	pg_logging_init(argv[0]);
 	pg_logging_set_level(PG_LOG_WARNING);
 	progname = get_progname(argv[0]);
@@ -1286,7 +1361,7 @@ main(int argc, char **argv)
 	}
 #endif
 
-	while ((c = getopt_long(argc, argv, "D:P:S:d:nv",
+	while ((c = getopt_long(argc, argv, "D:P:S:d:t:nv",
 							long_options, &option_index)) != -1)
 	{
 		switch (c)
@@ -1308,6 +1383,9 @@ main(int argc, char **argv)
 					num_dbs++;
 				}
 				break;
+			case 't':
+				wait_seconds = atoi(optarg);
+				break;
 			case 'n':
 				dry_run = true;
 				break;
@@ -1443,6 +1521,43 @@ main(int argc, char **argv)
 	/* subscriber PID file. */
 	snprintf(pidfile, MAXPGPATH, "%s/postmaster.pid", subscriber_dir);
 
+	/*
+	 * Start the standby server if it not running
+	 */
+	if (stat(pidfile, &statbuf) != 0)
+		start_standby_server(server_start_log);
+
+	/*
+	 * Exit the pg_subscriber if the node is not a standby server.
+	 */
+	conn = connect_database(dbinfo[0].subconninfo);
+	if (conn == NULL)
+		exit(1);
+
+	res = PQexec(conn, "SELECT pg_catalog.pg_is_in_recovery()");
+
+	if (PQresultStatus(res) != PGRES_TUPLES_OK)
+	{
+		pg_log_error("could not obtain recovery progress");
+		exit(1);
+	}
+
+	if (PQntuples(res) != 1)
+	{
+		pg_log_error("unexpected result from pg_is_in_recovery function");
+		exit(1);
+	}
+
+	/* check if the server is in recovery */
+	if (strcmp(PQgetvalue(res, 0, 0), "t") != 0)
+	{
+		pg_log_error("pg_subscriber is supported only on standby server");
+		exit(1);
+	}
+
+	PQclear(res);
+	disconnect_database(conn);
+
 	/*
 	 * Stop the subscriber if it is a standby server. Before executing the
 	 * transformation steps, make sure the subscriber is not running because
@@ -1532,25 +1647,7 @@ main(int argc, char **argv)
 	 * Start subscriber and wait until accepting connections.
 	 */
 	pg_log_info("starting the subscriber");
-
-	/* append timestamp with ISO 8601 format. */
-	gettimeofday(&time, NULL);
-	tt = (time_t) time.tv_sec;
-	strftime(timebuf, sizeof(timebuf), "%Y%m%dT%H%M%S", localtime(&tt));
-	snprintf(timebuf + strlen(timebuf), sizeof(timebuf) - strlen(timebuf),
-			 ".%03d", (int) (time.tv_usec / 1000));
-
-	server_start_log = (char *) pg_malloc0(MAXPGPATH);
-	len = snprintf(server_start_log, MAXPGPATH, "%s/%s/server_start_%s.log", subscriber_dir, PGS_OUTPUT_DIR, timebuf);
-	if (len >= MAXPGPATH)
-	{
-		pg_log_error("log file path is too long");
-		exit(1);
-	}
-
-	pg_ctl_cmd = psprintf("\"%s\" start -D \"%s\" -s -l \"%s\"", pg_ctl_path, subscriber_dir, server_start_log);
-	rc = system(pg_ctl_cmd);
-	pg_ctl_status(pg_ctl_cmd, rc, 1);
+	start_standby_server(server_start_log);
 
 	/*
 	 * Waiting the subscriber to be promoted.
diff --git a/src/bin/pg_basebackup/t/040_pg_subscriber.pl b/src/bin/pg_basebackup/t/040_pg_subscriber.pl
index 4ebff76b2d..e653df174d 100644
--- a/src/bin/pg_basebackup/t/040_pg_subscriber.pl
+++ b/src/bin/pg_basebackup/t/040_pg_subscriber.pl
@@ -40,5 +40,13 @@ command_fails(
 		'--subscriber-conninfo', 'dbname=postgres'
 	],
 	'no database name specified');
+command_fails(
+	[
+		'pg_subscriber', '--verbose',
+		'--pgdata', $datadir,
+		'--publisher-conninfo', 'dbname=postgres',
+		'--subscriber-conninfo', 'host=192.0.2.1 dbname=postgres'
+	],
+	'subscriber connection string specnfied non-local server');
 
 done_testing();
-- 
2.43.0

