From 5d0a9f43f66bc2e9830df802048cf527f6ee4f16 Mon Sep 17 00:00:00 2001
From: Thomas Munro <thomas.munro@gmail.com>
Date: Tue, 19 Mar 2019 18:53:29 +1300
Subject: [PATCH] Add DNS SRV support for LDAP server discovery.

LDAP servers can be advertised on a network by registering DNS SRV
records for _ldap._tcp.<domain>.  The OpenLDAP command-line tools
know how to find servers via those records, if no server name is
provided by the user.  Teach PostgreSQL to follow the same convention
using non-standard extensions provided by OpenLDAP, where available.

Author: Thomas Munro
Reviewed-by: Daniel Gustafsson
Discussion: https://postgr.es/m/CAEepm=2hAnSfhdsd6vXsM6VZVN0br-FbAZ-O+Swk18S5HkCP=A@mail.gmail.com
---
 doc/src/sgml/client-auth.sgml |  21 ++++-
 src/backend/libpq/auth.c      | 145 ++++++++++++++++++++++++----------
 src/backend/libpq/hba.c       |   2 +
 3 files changed, 125 insertions(+), 43 deletions(-)

diff --git a/doc/src/sgml/client-auth.sgml b/doc/src/sgml/client-auth.sgml
index 411f1e1679..b6d44f2d66 100644
--- a/doc/src/sgml/client-auth.sgml
+++ b/doc/src/sgml/client-auth.sgml
@@ -1655,7 +1655,8 @@ ldap[s]://<replaceable>host</replaceable>[:<replaceable>port</replaceable>]/<rep
         </para>
 
         <para>
-         LDAP URLs are currently only supported with OpenLDAP, not on Windows.
+         LDAP URLs are currently only supported with
+         <productname>OpenLDAP</productname>, not on Windows.
         </para>
        </listitem>
       </varlistentry>
@@ -1678,6 +1679,15 @@ ldap[s]://<replaceable>host</replaceable>[:<replaceable>port</replaceable>]/<rep
     <literal>ldapsearchattribute=uid</literal>.
    </para>
 
+   <para>
+     If <productname>PostgreSQL</productname> was compiled with
+     <productname>OpenLDAP</productname> as the LDAP client library, the
+     <literal>ldapserver</literal> setting may be omitted.  In that case, a
+     list of hostnames and ports is looked up via RFC 2782 DNS service records.
+     The name <literal>_ldap._tcp.DOMAIN</literal> is looked up, where
+     <literal>DOMAIN</literal> is extracted from <literal>basedn</literal>.
+   </para>
+
    <para>
     Here is an example for a simple-bind LDAP configuration:
 <programlisting>
@@ -1723,6 +1733,15 @@ host ... ldap ldapserver=ldap.example.net ldapbasedn="dc=example, dc=net" ldapse
 </programlisting>
    </para>
 
+   <para>
+    Here is an example for a search+bind configuration that uses DNS SRV
+    discovery to find the hostname(s) and port(s) for the LDAP service for the
+    domain name <literal>example.net</literal>:
+<programlisting>
+host ... ldap ldapbasedn="dc=example,dc=net"
+</programlisting>
+   </para>
+
    <tip>
     <para>
      Since LDAP often uses commas and spaces to separate the different
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index fb86e9e9d4..d3bbbdac6d 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -2369,44 +2369,87 @@ InitializeLDAPConnection(Port *port, LDAP **ldap)
 #else
 #ifdef HAVE_LDAP_INITIALIZE
 	{
-		const char *hostnames = port->hba->ldapserver;
-		char	   *uris = NULL;
+		StringInfoData uris;
+		char	   *hostlist = NULL;
+		char	   *p;
+		bool		append_port;
+
+		/* We'll build the list of scheme://hostname:port in a StringInfo */
+		initStringInfo(&uris);
 
 		/*
-		 * We have a space-separated list of hostnames.  Convert it
-		 * to a space-separated list of URIs.
+		 * If pg_hba.conf provided no hostnames, we can ask OpenLDAP to try to
+		 * find some by extracting a domain name from the base DN and looking
+		 * up DSN SRV records for _ldap._tcp.<domain>.  The same convention
+		 * is used by the OpenLDAP command line tools.
 		 */
+		if (!port->hba->ldapserver || port->hba->ldapserver[0] == '\0')
+		{
+			char	   *domain;
+
+			/* ou=blah,dc=foo,dc=bar -> foo.bar */
+			if (ldap_dn2domain(port->hba->ldapbasedn, &domain))
+			{
+				ereport(LOG,
+						(errmsg("could not extract domain name from basedn")));
+				return STATUS_ERROR;
+			}
+
+			/* Look up a list of LDAP server hosts and port numbers */
+			if (ldap_domain2hostlist(domain, &hostlist))
+			{
+				ereport(LOG,
+						(errmsg("LDAP authentication could not find DNS SRV records for \"%s\"",
+								domain),
+						(errhint("Set an LDAP server name explicitly."))));
+				ldap_memfree(domain);
+				return STATUS_ERROR;
+			}
+			ldap_memfree(domain);
+
+			/* We have a space-separated list of host:port entries */
+			p = hostlist;
+			append_port = false;
+		}
+		else
+		{
+			/* We have a space-separated list of hosts from pg_hba.conf */
+			p = port->hba->ldapserver;
+			append_port = true;
+		}
+
+		/* Build a space-separated list of full URIs */
 		do
 		{
-			char	   *hostname;
-			size_t		hostname_size;
-			char	   *new_uris;
-
-			/* Find the leading hostname. */
-			hostname_size = strcspn(hostnames, " ");
-			hostname = pnstrdup(hostnames, hostname_size);
-
-			/* Append a URI for this hostname. */
-			new_uris = psprintf("%s%s%s://%s:%d",
-								uris ? uris : "",
-								uris ? " " : "",
-								scheme,
-								hostname,
-								port->hba->ldapport);
-
-			pfree(hostname);
-			if (uris)
-				pfree(uris);
-			uris = new_uris;
-
-			/* Step over this hostname and any spaces. */
-			hostnames += hostname_size;
-			while (*hostnames == ' ')
-				++hostnames;
-		} while (*hostnames);
-
-		r = ldap_initialize(ldap, uris);
-		pfree(uris);
+			size_t		size;
+
+			/* Find the span of the next entry */
+			size = strcspn(p, " ");
+
+			/* Append a space separator if this isn't the first URI */
+			if (uris.len > 0)
+				appendStringInfoChar(&uris, ' ');
+
+			/* Append scheme://host:port */
+			appendStringInfoString(&uris, scheme);
+			appendStringInfoString(&uris, "://");
+			appendBinaryStringInfo(&uris, p, size);
+			if (append_port)
+				appendStringInfo(&uris, ":%d", port->hba->ldapport);
+
+			/* Step over this entry and any number of trailing spaces */
+			p += size;
+			while (*p == ' ')
+				++p;
+		} while (*p);
+
+		/* Free memory from OpenLDAP if we looked up SRV records */
+		if (hostlist)
+			ldap_memfree(hostlist);
+
+		/* Finally, try to connect using the URI list */
+		r = ldap_initialize(ldap, uris.data);
+		pfree(uris.data);
 		if (r != LDAP_SUCCESS)
 		{
 			ereport(LOG,
@@ -2552,13 +2595,31 @@ CheckLDAPAuth(Port *port)
 	LDAP	   *ldap;
 	int			r;
 	char	   *fulluser;
+	const char *server_name;
 
+#ifdef HAVE_LDAP_INITIALIZE
+	/* OpenLDAP allows empty hostname, if we have a basedn. */
+	if ((!port->hba->ldapserver || port->hba->ldapserver[0] == '\0') &&
+		(!port->hba->ldapbasedn || port->hba->ldapbasedn[0] == '\0'))
+	{
+		ereport(LOG,
+				(errmsg("LDAP server not specified, and no ldapbasedn")));
+		return STATUS_ERROR;
+	}
+#else
 	if (!port->hba->ldapserver || port->hba->ldapserver[0] == '\0')
 	{
 		ereport(LOG,
 				(errmsg("LDAP server not specified")));
 		return STATUS_ERROR;
 	}
+#endif
+
+	/*
+	 * If we're using SRV records, we don't have a server name so we'll
+	 * just show an empty string in error messages.
+	 */
+	server_name = port->hba->ldapserver ? port->hba->ldapserver : "";
 
 	if (port->hba->ldapport == 0)
 	{
@@ -2630,7 +2691,7 @@ CheckLDAPAuth(Port *port)
 			ereport(LOG,
 					(errmsg("could not perform initial LDAP bind for ldapbinddn \"%s\" on server \"%s\": %s",
 							port->hba->ldapbinddn ? port->hba->ldapbinddn : "",
-							port->hba->ldapserver,
+							server_name,
 							ldap_err2string(r)),
 					 errdetail_for_ldap(ldap)));
 			ldap_unbind(ldap);
@@ -2658,7 +2719,9 @@ CheckLDAPAuth(Port *port)
 		{
 			ereport(LOG,
 					(errmsg("could not search LDAP for filter \"%s\" on server \"%s\": %s",
-							filter, port->hba->ldapserver, ldap_err2string(r)),
+							filter,
+							server_name,
+							ldap_err2string(r)),
 					 errdetail_for_ldap(ldap)));
 			ldap_unbind(ldap);
 			pfree(passwd);
@@ -2673,14 +2736,13 @@ CheckLDAPAuth(Port *port)
 				ereport(LOG,
 						(errmsg("LDAP user \"%s\" does not exist", port->user_name),
 						 errdetail("LDAP search for filter \"%s\" on server \"%s\" returned no entries.",
-								   filter, port->hba->ldapserver)));
+								   filter, server_name)));
 			else
 				ereport(LOG,
 						(errmsg("LDAP user \"%s\" is not unique", port->user_name),
 						 errdetail_plural("LDAP search for filter \"%s\" on server \"%s\" returned %d entry.",
 										  "LDAP search for filter \"%s\" on server \"%s\" returned %d entries.",
-										  count,
-										  filter, port->hba->ldapserver, count)));
+										  count, filter, server_name, count)));
 
 			ldap_unbind(ldap);
 			pfree(passwd);
@@ -2698,8 +2760,7 @@ CheckLDAPAuth(Port *port)
 			(void) ldap_get_option(ldap, LDAP_OPT_ERROR_NUMBER, &error);
 			ereport(LOG,
 					(errmsg("could not get dn for the first entry matching \"%s\" on server \"%s\": %s",
-							filter, port->hba->ldapserver,
-							ldap_err2string(error)),
+							filter, server_name, ldap_err2string(error)),
 					 errdetail_for_ldap(ldap)));
 			ldap_unbind(ldap);
 			pfree(passwd);
@@ -2719,7 +2780,7 @@ CheckLDAPAuth(Port *port)
 		{
 			ereport(LOG,
 					(errmsg("could not unbind after searching for user \"%s\" on server \"%s\"",
-							fulluser, port->hba->ldapserver)));
+							fulluser, server_name)));
 			pfree(passwd);
 			pfree(fulluser);
 			return STATUS_ERROR;
@@ -2750,7 +2811,7 @@ CheckLDAPAuth(Port *port)
 	{
 		ereport(LOG,
 				(errmsg("LDAP login failed for user \"%s\" on server \"%s\": %s",
-						fulluser, port->hba->ldapserver, ldap_err2string(r)),
+						fulluser, server_name, ldap_err2string(r)),
 				 errdetail_for_ldap(ldap)));
 		ldap_unbind(ldap);
 		pfree(passwd);
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
index 59de1b7639..9c4e81a0e9 100644
--- a/src/backend/libpq/hba.c
+++ b/src/backend/libpq/hba.c
@@ -1500,7 +1500,9 @@ parse_hba_line(TokenizedLine *tok_line, int elevel)
 	 */
 	if (parsedline->auth_method == uaLDAP)
 	{
+#ifndef HAVE_LDAP_INITIALIZE
 		MANDATORY_AUTH_ARG(parsedline->ldapserver, "ldapserver", "ldap");
+#endif
 
 		/*
 		 * LDAP can operate in two modes: either with a direct bind, using
-- 
2.21.0

