From 0e0d747c60d564991fc375f439649ac6f35f4578 Mon Sep 17 00:00:00 2001
From: Jelte Fennema <jelte.fennema@microsoft.com>
Date: Wed, 12 Jan 2022 09:52:05 +0100
Subject: [PATCH] Add non-blocking version of PQcancel

The existing PQcancel API is using blocking IO. This makes PQcancel
impossible to use in an event loop based codebase, without blocking the
event loop until the call returns.

This patch adds a new cancellation API to libpq which is called
PQcancelConnectionStart. This API can be used to send cancellations in a
non-blocking fashion. To do this it internally uses regular PGconn
connection establishment. This has as a downside that
PQcancelConnectionStart cannot be safely called from a  signal handler.

Luckily, this should be fine for most usages of this API. Since most
code that's using an event loop handles signals in that event loop as
well (as opposed to calling functions from the signal handler directly).

There are also a few advantages of this approach:
1. No need to add and maintain a second non-blocking connection
   establishment codepath.
2. Cancel connections benefit automatically from any improvements made
   to the normal connection establishment codepath. Examples of things
   that it currently gets for free currently are TLS support and
   keepalive settings.

This patch also includes a test for this new API (and also the already
existing cancellation APIs). The test can be easily run like this:

    cd src/test/modules/libpq_pipeline
    make && ./libpq_pipeline cancel

NOTE: I have not tested this with GSS for the moment. My expectation is
that using this new API with a GSS connection will result in a
CONNECTION_BAD status when calling PQcancelStatus. The reason for this
is that GSS reads will also need to communicate back that an EOF was
found, just like I've done for TLS reads and unencrypted reads. Since in
case of a cancel connection an EOF is actually expected, and should not
be treated as an error.
---
 src/interfaces/libpq/exports.txt              |   7 +
 src/interfaces/libpq/fe-connect.c             | 192 +++++++++++++++++-
 src/interfaces/libpq/fe-misc.c                |  15 +-
 src/interfaces/libpq/fe-secure-openssl.c      |   2 +-
 src/interfaces/libpq/fe-secure.c              |   3 +
 src/interfaces/libpq/libpq-fe.h               |  13 ++
 src/interfaces/libpq/libpq-int.h              |   8 +
 .../modules/libpq_pipeline/libpq_pipeline.c   | 115 ++++++++++-
 .../libpq_pipeline/t/001_libpq_pipeline.pl    |   2 +-
 9 files changed, 348 insertions(+), 9 deletions(-)

diff --git a/src/interfaces/libpq/exports.txt b/src/interfaces/libpq/exports.txt
index e8bcc88370..64364afeaf 100644
--- a/src/interfaces/libpq/exports.txt
+++ b/src/interfaces/libpq/exports.txt
@@ -186,3 +186,10 @@ PQpipelineStatus          183
 PQsetTraceFlags           184
 PQmblenBounded            185
 PQsendFlushRequest        186
+PQcancelConnect           187
+PQcancelConnectStart      188
+PQcancelConnectPoll       189
+PQcancelStatus            189
+PQcancelSocket            190
+PQcancelErrorMessage      191
+PQcancelFinish            192
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 5fc16be849..347d32ad5f 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -378,6 +378,7 @@ static int	connectDBComplete(PGconn *conn);
 static PGPing internal_ping(PGconn *conn);
 static PGconn *makeEmptyPGconn(void);
 static bool fillPGconn(PGconn *conn, PQconninfoOption *connOptions);
+static bool copyPGconn(PGconn *srcConn, PGconn *dstConn);
 static void freePGconn(PGconn *conn);
 static void closePGconn(PGconn *conn);
 static void release_conn_addrinfo(PGconn *conn);
@@ -604,8 +605,11 @@ pqDropServerData(PGconn *conn)
 	if (conn->write_err_msg)
 		free(conn->write_err_msg);
 	conn->write_err_msg = NULL;
-	conn->be_pid = 0;
-	conn->be_key = 0;
+	if (!conn->cancelRequest)
+	{
+		conn->be_pid = 0;
+		conn->be_key = 0;
+	}
 }
 
 
@@ -737,6 +741,120 @@ PQping(const char *conninfo)
 	return ret;
 }
 
+/*
+ *		PQcancelConnectStart
+ *
+ * Asynchronously cancel a request on the given connection. This requires
+ * polling the returned PGconn to actually complete the cancellation of the
+ * request.
+ */
+PGcancelConn *
+PQcancelConnectStart(PGconn *conn)
+{
+	PGconn	   *cancelConn = makeEmptyPGconn();
+
+	if (cancelConn == NULL)
+		return NULL;
+
+	/*
+	 * Indicate that this connection is used to send a cancellation
+	 */
+	cancelConn->cancelRequest = true;
+
+	if (!copyPGconn(conn, cancelConn))
+		return (PGcancelConn *) cancelConn;
+
+	/*
+	 * Copy over information needed to cancel
+	 */
+	cancelConn->be_pid = conn->be_pid;
+	cancelConn->be_key = conn->be_key;
+
+	/*
+	 * Compute derived options
+	 */
+	if (!connectOptions2(cancelConn))
+		return (PGcancelConn *) cancelConn;
+
+	/*
+	 * Connect to the database
+	 */
+	if (!connectDBStart(cancelConn))
+	{
+		/* Just in case we failed to set it in connectDBStart */
+		cancelConn->status = CONNECTION_BAD;
+	}
+
+	return (PGcancelConn *) cancelConn;
+}
+
+/*
+ *		PQcancelConnect
+ *
+ * Cancel a request on the given connection
+ */
+PGcancelConn *
+PQcancelConnect(PGconn *conn)
+{
+	PGcancelConn *cancelConn = PQcancelConnectStart(conn);
+
+	if (cancelConn && cancelConn->conn.status != CONNECTION_BAD)
+		(void) connectDBComplete(&cancelConn->conn);
+
+	return cancelConn;
+}
+
+/*
+ *		PQcancelConnectPoll
+ *
+ * Poll a cancel connection. For usage details see the PQconnectPoll.
+ */
+PostgresPollingStatusType
+PQcancelConnectPoll(PGcancelConn * cancelConn)
+{
+	return PQconnectPoll((PGconn *) cancelConn);
+}
+
+/*
+ *		PQcancelStatus
+ *
+ * Get the status of a cancel connection.
+ */
+ConnStatusType
+PQcancelStatus(const PGcancelConn * cancelConn)
+{
+	return PQstatus((PGconn *) cancelConn);
+}
+
+/*
+ *		PQcancelSocket
+ *
+ * Get the socket of the cancel connection.
+ */
+int
+PQcancelSocket(const PGcancelConn * cancelConn)
+{
+	return PQsocket((PGconn *) cancelConn);
+}
+
+/*
+ *		PQcancelErrorMessage
+ *
+ * Get the socket of the cancel connection.
+ */
+char *
+PQcancelErrorMessage(const PGcancelConn * cancelConn)
+{
+	return PQerrorMessage((PGconn *) cancelConn);
+}
+
+void
+PQcancelFinish(PGcancelConn * cancelConn)
+{
+	PQfinish((PGconn *) cancelConn);
+}
+
+
 /*
  *		PQconnectStartParams
  *
@@ -914,6 +1032,46 @@ fillPGconn(PGconn *conn, PQconninfoOption *connOptions)
 	return true;
 }
 
+/*
+ * Copy over option values from srcConn to dstConn
+ *
+ * Don't put anything cute here --- intelligence should be in
+ * connectOptions2 ...
+ *
+ * Returns true on success. On failure, returns false and sets error message of
+ * dstConn.
+ */
+static bool
+copyPGconn(PGconn *srcConn, PGconn *dstConn)
+{
+	const internalPQconninfoOption *option;
+
+	/* copy over connection options */
+	for (option = PQconninfoOptions; option->keyword; option++)
+	{
+		if (option->connofs >= 0)
+		{
+			const char **tmp = (const char **) ((char *) srcConn + option->connofs);
+
+			if (*tmp)
+			{
+				char	  **dstConnmember = (char **) ((char *) dstConn + option->connofs);
+
+				if (*dstConnmember)
+					free(*dstConnmember);
+				*dstConnmember = strdup(*tmp);
+				if (*dstConnmember == NULL)
+				{
+					appendPQExpBufferStr(&dstConn->errorMessage,
+										 libpq_gettext("out of memory\n"));
+					return false;
+				}
+			}
+		}
+	}
+	return true;
+}
+
 /*
  *		connectOptions1
  *
@@ -2276,6 +2434,17 @@ PQconnectPoll(PGconn *conn)
 				/* Load waiting data */
 				int			n = pqReadData(conn);
 
+				if (n == -2 && conn->cancelRequest)
+				{
+					/*
+					 * This is the expected end state for cancel connections.
+					 * They are closed once the cancel is processed by the
+					 * server.
+					 */
+					conn->status = CONNECTION_CANCEL_FINISHED;
+					resetPQExpBuffer(&conn->errorMessage);
+					return PGRES_POLLING_OK;
+				}
 				if (n < 0)
 					goto error_return;
 				if (n == 0)
@@ -2950,6 +3119,25 @@ keep_going:						/* We will come back to here until there is
 				}
 #endif							/* USE_SSL */
 
+				if (conn->cancelRequest)
+				{
+					CancelRequestPacket cancelpacket;
+
+					packetlen = sizeof(cancelpacket);
+					cancelpacket.cancelRequestCode = (MsgType) pg_hton32(CANCEL_REQUEST_CODE);
+					cancelpacket.backendPID = pg_hton32(conn->be_pid);
+					cancelpacket.cancelAuthCode = pg_hton32(conn->be_key);
+					if (pqPacketSend(conn, 0, &cancelpacket, packetlen) != STATUS_OK)
+					{
+						appendPQExpBuffer(&conn->errorMessage,
+										  libpq_gettext("could not send cancel packet: %s\n"),
+										  SOCK_STRERROR(SOCK_ERRNO, sebuf, sizeof(sebuf)));
+						goto error_return;
+					}
+					conn->status = CONNECTION_AWAITING_RESPONSE;
+					return PGRES_POLLING_READING;
+				}
+
 				/*
 				 * Build the startup packet.
 				 */
diff --git a/src/interfaces/libpq/fe-misc.c b/src/interfaces/libpq/fe-misc.c
index 7fcfe08fd2..a95d63ffcd 100644
--- a/src/interfaces/libpq/fe-misc.c
+++ b/src/interfaces/libpq/fe-misc.c
@@ -558,8 +558,11 @@ pqPutMsgEnd(PGconn *conn)
  * Possible return values:
  *	 1: successfully loaded at least one more byte
  *	 0: no data is presently available, but no error detected
- *	-1: error detected (including EOF = connection closure);
+ *	-1: error detected (excluding EOF = connection closure);
  *		conn->errorMessage set
+ *	-2: EOF detected, connection is closed
+ *		conn->errorMessage set
+ *
  * NOTE: callers must not assume that pointers or indexes into conn->inBuffer
  * remain valid across this call!
  * ----------
@@ -642,7 +645,7 @@ retry3:
 
 			default:
 				/* pqsecure_read set the error message for us */
-				return -1;
+				return nread;
 		}
 	}
 	if (nread > 0)
@@ -737,7 +740,7 @@ retry4:
 
 			default:
 				/* pqsecure_read set the error message for us */
-				return -1;
+				return nread;
 		}
 	}
 	if (nread > 0)
@@ -755,13 +758,17 @@ definitelyEOF:
 						 libpq_gettext("server closed the connection unexpectedly\n"
 									   "\tThis probably means the server terminated abnormally\n"
 									   "\tbefore or while processing the request.\n"));
+	/* Do *not* drop any already-read data; caller still wants it */
+	pqDropConnection(conn, false);
+	conn->status = CONNECTION_BAD;	/* No more connection to backend */
+	return -2;
 
 	/* Come here if lower-level code already set a suitable errorMessage */
 definitelyFailed:
 	/* Do *not* drop any already-read data; caller still wants it */
 	pqDropConnection(conn, false);
 	conn->status = CONNECTION_BAD;	/* No more connection to backend */
-	return -1;
+	return nread < 0 ? nread : -1;
 }
 
 /*
diff --git a/src/interfaces/libpq/fe-secure-openssl.c b/src/interfaces/libpq/fe-secure-openssl.c
index 9f735ba437..3cd65fa276 100644
--- a/src/interfaces/libpq/fe-secure-openssl.c
+++ b/src/interfaces/libpq/fe-secure-openssl.c
@@ -252,7 +252,7 @@ rloop:
 			appendPQExpBufferStr(&conn->errorMessage,
 								 libpq_gettext("SSL connection has been closed unexpectedly\n"));
 			result_errno = ECONNRESET;
-			n = -1;
+			n = -2;
 			break;
 		default:
 			appendPQExpBuffer(&conn->errorMessage,
diff --git a/src/interfaces/libpq/fe-secure.c b/src/interfaces/libpq/fe-secure.c
index 0b998e254d..b2c66f47a5 100644
--- a/src/interfaces/libpq/fe-secure.c
+++ b/src/interfaces/libpq/fe-secure.c
@@ -201,6 +201,9 @@ pqsecure_close(PGconn *conn)
  * On failure, this function is responsible for appending a suitable message
  * to conn->errorMessage.  The caller must still inspect errno, but only
  * to determine whether to continue/retry after error.
+ *
+ * Returns -1 in case of failures, except in the case of clean connection
+ * closure then it returns -2.
  */
 ssize_t
 pqsecure_read(PGconn *conn, void *ptr, size_t len)
diff --git a/src/interfaces/libpq/libpq-fe.h b/src/interfaces/libpq/libpq-fe.h
index 20eb855abc..39aed5db3e 100644
--- a/src/interfaces/libpq/libpq-fe.h
+++ b/src/interfaces/libpq/libpq-fe.h
@@ -57,6 +57,7 @@ typedef enum
 {
 	CONNECTION_OK,
 	CONNECTION_BAD,
+	CONNECTION_CANCEL_FINISHED,
 	/* Non-blocking mode only below here */
 
 	/*
@@ -163,6 +164,11 @@ typedef enum
  */
 typedef struct pg_conn PGconn;
 
+/* PGcancelConn encapsulates a cancel connection to the backend.
+ * The contents of this struct are not supposed to be known to applications.
+ */
+typedef struct pg_cancel_conn PGcancelConn;
+
 /* PGresult encapsulates the result of a query (or more precisely, of a single
  * SQL command --- a query string given to PQsendQuery can contain multiple
  * commands and thus return multiple PGresult objects).
@@ -327,6 +333,13 @@ extern void PQfreeCancel(PGcancel *cancel);
 
 /* issue a cancel request */
 extern int	PQcancel(PGcancel *cancel, char *errbuf, int errbufsize);
+extern PGcancelConn * PQcancelConnectStart(PGconn *conn);
+extern PGcancelConn * PQcancelConnect(PGconn *conn);
+extern PostgresPollingStatusType PQcancelConnectPoll(PGcancelConn * cancelConn);
+extern ConnStatusType PQcancelStatus(const PGcancelConn * cancelConn);
+extern int	PQcancelSocket(const PGcancelConn * cancelConn);
+extern char *PQcancelErrorMessage(const PGcancelConn * cancelConn);
+extern void PQcancelFinish(PGcancelConn * cancelConn);
 
 /* backwards compatible version of PQcancel; not thread-safe */
 extern int	PQrequestCancel(PGconn *conn);
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index fcce13843e..8af3dd0ee7 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -394,6 +394,8 @@ struct pg_conn
 	char	   *ssl_max_protocol_version;	/* maximum TLS protocol version */
 	char	   *target_session_attrs;	/* desired session properties */
 
+	bool		cancelRequest;
+
 	/* Optional file to write trace info to */
 	FILE	   *Pfdebug;
 	int			traceFlags;
@@ -574,6 +576,11 @@ struct pg_conn
 	PQExpBufferData workBuffer; /* expansible string */
 };
 
+struct pg_cancel_conn
+{
+	PGconn		conn;
+};
+
 /* PGcancel stores all data necessary to cancel a connection. A copy of this
  * data is required to safely cancel a connection running on a different
  * thread.
@@ -691,6 +698,7 @@ extern int	pqPutInt(int value, size_t bytes, PGconn *conn);
 extern int	pqPutMsgStart(char msg_type, PGconn *conn);
 extern int	pqPutMsgEnd(PGconn *conn);
 extern int	pqReadData(PGconn *conn);
+extern int	pqReadDataOrEof(PGconn *conn);
 extern int	pqFlush(PGconn *conn);
 extern int	pqWait(int forRead, int forWrite, PGconn *conn);
 extern int	pqWaitTimed(int forRead, int forWrite, PGconn *conn,
diff --git a/src/test/modules/libpq_pipeline/libpq_pipeline.c b/src/test/modules/libpq_pipeline/libpq_pipeline.c
index 0ff563f59a..27188d43bb 100644
--- a/src/test/modules/libpq_pipeline/libpq_pipeline.c
+++ b/src/test/modules/libpq_pipeline/libpq_pipeline.c
@@ -86,6 +86,116 @@ pg_fatal_impl(int line, const char *fmt,...)
 	exit(1);
 }
 
+static void
+confirm_query_cancelled(PGconn *conn)
+{
+	PGresult   *res = NULL;
+
+	res = PQgetResult(conn);
+	if (res == NULL)
+		pg_fatal("PQgetResult returned null: %s",
+				 PQerrorMessage(conn));
+	if (PQresultStatus(res) != PGRES_FATAL_ERROR)
+		pg_fatal("query did not fail when it was expected");
+	if (strcmp(PQresultErrorField(res, PG_DIAG_SQLSTATE), "57014") != 0)
+		pg_fatal("query failed with a different error than cancellation: %s", PQerrorMessage(conn));
+	PQclear(res);
+	while (PQisBusy(conn))
+	{
+		PQconsumeInput(conn);
+	}
+}
+
+static void
+test_cancel(PGconn *conn)
+{
+	PGcancel   *cancel = NULL;
+	PGcancelConn *cancelConn = NULL;
+	char		errorbuf[256];
+
+	if (PQsetnonblocking(conn, 1) != 0)
+		pg_fatal("failed to set nonblocking mode: %s", PQerrorMessage(conn));
+
+	/* test PQrequestcancel */
+	if (PQsendQuery(conn, "SELECT pg_sleep(3)") != 1)
+		pg_fatal("failed to send query: %s", PQerrorMessage(conn));
+	PQrequestCancel(conn);
+	confirm_query_cancelled(conn);
+
+	/* test PQcancel */
+	if (PQsendQuery(conn, "SELECT pg_sleep(3)") != 1)
+		pg_fatal("failed to send query: %s", PQerrorMessage(conn));
+	cancel = PQgetCancel(conn);
+	PQcancel(cancel, errorbuf, sizeof(errorbuf));
+	confirm_query_cancelled(conn);
+
+	/* test PQcancelConnect */
+	if (PQsendQuery(conn, "SELECT pg_sleep(3)") != 1)
+		pg_fatal("failed to send query: %s", PQerrorMessage(conn));
+	cancelConn = PQcancelConnect(conn);
+	if (PQcancelStatus(cancelConn) != CONNECTION_CANCEL_FINISHED)
+		pg_fatal("unexpected cancel connection status: %s", PQcancelErrorMessage(cancelConn));
+	confirm_query_cancelled(conn);
+	PQcancelFinish(cancelConn);
+
+	/* test PQcancelConnectStart and then polling with PQcancelConnectPoll */
+	if (PQsendQuery(conn, "SELECT pg_sleep(3)") != 1)
+		pg_fatal("failed to send query: %s", PQerrorMessage(conn));
+	cancelConn = PQcancelConnectStart(conn);
+	if (PQcancelStatus(cancelConn) == CONNECTION_BAD)
+		pg_fatal("bad cancel connection: %s", PQcancelErrorMessage(cancelConn));
+	while (true)
+	{
+		struct timeval tv;
+		fd_set		input_mask;
+		fd_set		output_mask;
+		PostgresPollingStatusType pollres = PQcancelConnectPoll(cancelConn);
+		int			sock = PQcancelSocket(cancelConn);
+
+		if (pollres == PGRES_POLLING_OK)
+		{
+			break;
+		}
+
+		FD_ZERO(&input_mask);
+		FD_ZERO(&output_mask);
+		switch (pollres)
+		{
+			case PGRES_POLLING_READING:
+				pg_debug("polling for reads\n");
+				FD_SET(sock, &input_mask);
+				break;
+			case PGRES_POLLING_WRITING:
+				pg_debug("polling for writes\n");
+				FD_SET(sock, &output_mask);
+				break;
+			default:
+				pg_fatal("bad cancel connection: %s", PQcancelErrorMessage(cancelConn));
+		}
+
+		if (sock < 0)
+			pg_fatal("sock did not exist: %s", PQcancelErrorMessage(cancelConn));
+
+		tv.tv_sec = 3;
+		tv.tv_usec = 0;
+
+		while (true)
+		{
+			if (select(sock + 1, &input_mask, &output_mask, NULL, &tv) < 0)
+			{
+				if (errno == EINTR)
+					continue;
+				pg_fatal("select() failed: %m");
+			}
+			break;
+		}
+	}
+	if (PQcancelStatus(cancelConn) != CONNECTION_CANCEL_FINISHED)
+		pg_fatal("unexpected cancel connection status: %s", PQcancelErrorMessage(cancelConn));
+	confirm_query_cancelled(conn);
+	PQcancelFinish(cancelConn);
+}
+
 static void
 test_disallowed_in_pipeline(PGconn *conn)
 {
@@ -1555,6 +1665,7 @@ print_test_list(void)
 	printf("singlerow\n");
 	printf("transaction\n");
 	printf("uniqviol\n");
+	printf("cancel\n");
 }
 
 int
@@ -1642,7 +1753,9 @@ main(int argc, char **argv)
 						PQTRACE_SUPPRESS_TIMESTAMPS | PQTRACE_REGRESS_MODE);
 	}
 
-	if (strcmp(testname, "disallowed_in_pipeline") == 0)
+	if (strcmp(testname, "cancel") == 0)
+		test_cancel(conn);
+	else if (strcmp(testname, "disallowed_in_pipeline") == 0)
 		test_disallowed_in_pipeline(conn);
 	else if (strcmp(testname, "multi_pipelines") == 0)
 		test_multi_pipelines(conn);
diff --git a/src/test/modules/libpq_pipeline/t/001_libpq_pipeline.pl b/src/test/modules/libpq_pipeline/t/001_libpq_pipeline.pl
index 0c164dcaba..e0773543ae 100644
--- a/src/test/modules/libpq_pipeline/t/001_libpq_pipeline.pl
+++ b/src/test/modules/libpq_pipeline/t/001_libpq_pipeline.pl
@@ -26,7 +26,7 @@ for my $testname (@tests)
 	my @extraargs = ('-r', $numrows);
 	my $cmptrace = grep(/^$testname$/,
 		qw(simple_pipeline nosync multi_pipelines prepared singlerow
-		  pipeline_abort transaction disallowed_in_pipeline)) > 0;
+		  pipeline_abort transaction disallowed_in_pipeline cancel)) > 0;
 
 	# For a bunch of tests, generate a libpq trace file too.
 	my $traceout = "$PostgreSQL::Test::Utils::tmp_check/traces/$testname.trace";
-- 
2.17.1

