Module Name:    src
Committed By:   christos
Date:           Fri May  3 19:31:13 UTC 2013

Modified Files:
        src/lib/libc/net: getaddrinfo.c

Log Message:
PR/32373, PR/25827: Add SRV lookup in getaddrinfo(3)
Per RFC 2782 (DNS SRV resource records), but only enabled if AI_SRV is set.


To generate a diff of this commit:
cvs rdiff -u -r1.102 -r1.103 src/lib/libc/net/getaddrinfo.c

Please note that diffs are not public domain; they are subject to the
copyright notices on the relevant files.

Modified files:

Index: src/lib/libc/net/getaddrinfo.c
diff -u src/lib/libc/net/getaddrinfo.c:1.102 src/lib/libc/net/getaddrinfo.c:1.103
--- src/lib/libc/net/getaddrinfo.c:1.102	Fri May  3 15:24:52 2013
+++ src/lib/libc/net/getaddrinfo.c	Fri May  3 15:31:13 2013
@@ -1,4 +1,4 @@
-/*	$NetBSD: getaddrinfo.c,v 1.102 2013/05/03 19:24:52 christos Exp $ */
+/*	$NetBSD: getaddrinfo.c,v 1.103 2013/05/03 19:31:13 christos Exp $	*/
 /*	$KAME: getaddrinfo.c,v 1.29 2000/08/31 17:26:57 itojun Exp $	*/
 
 /*
@@ -55,7 +55,7 @@
 
 #include <sys/cdefs.h>
 #if defined(LIBC_SCCS) && !defined(lint)
-__RCSID("$NetBSD: getaddrinfo.c,v 1.102 2013/05/03 19:24:52 christos Exp $");
+__RCSID("$NetBSD: getaddrinfo.c,v 1.103 2013/05/03 19:31:13 christos Exp $");
 #endif /* LIBC_SCCS and not lint */
 
 #include "namespace.h"
@@ -191,6 +191,13 @@ struct res_target {
 	int n;			/* result length */
 };
 
+struct srvinfo {	/* one SRV answer record parsed by getanswer() */
+       struct srvinfo *next;	/* next record; list kept ordered by priority */
+       char name[MAXDNAME];	/* SRV target host name */
+       int port, pri, weight;	/* SRV RDATA fields */
+};
+
+static int gai_srvok(const char *);
 static int str2number(const char *);
 static int explore_fqdn(const struct addrinfo *, const char *,
     const char *, struct addrinfo **, struct servent_data *);
@@ -217,6 +224,12 @@ static int ip6_str2scopeid(char *, struc
 static struct addrinfo *getanswer(const querybuf *, int, const char *, int,
     const struct addrinfo *);
 static void aisort(struct addrinfo *s, res_state res);
+static struct addrinfo * _dns_query(struct res_target *,
+    const struct addrinfo *, res_state, int);
+static struct addrinfo * _dns_srv_lookup(const char *, const char *,
+    const struct addrinfo *);
+static struct addrinfo * _dns_host_lookup(const char *,
+    const struct addrinfo *);
 static int _dns_getaddrinfo(void *, void *, va_list);
 static void _sethtent(FILE **);
 static void _endhtent(FILE **);
@@ -319,6 +332,79 @@ freeaddrinfo(struct addrinfo *ai)
 	} while (ai);
 }
 
+/*
+ * Character classes for gai_srvok(), spelled out by hand rather than
+ * taken from <ctype.h> so that the current locale cannot affect them.
+ */
+#define PERIOD '.'
+#define hyphenchar(c) ((c) == '-')
+#define periodchar(c) ((c) == PERIOD)
+#define underschar(c) ((c) == '_')
+#define alphachar(c) (((c) >= 'a' && (c) <= 'z') || ((c) >= 'A' && (c) <= 'Z'))
+#define digitchar(c) ((c) >= '0' && (c) <= '9')
+
+#define firstchar(c)  (alphachar(c) || digitchar(c) || underschar(c))
+#define lastchar(c)   (alphachar(c) || digitchar(c))
+#define middlechar(c) (lastchar(c) || hyphenchar(c))
+
+/*
+ * Check that dn is an acceptable SRV owner name, such as
+ * "_http._tcp.example.com".  Within each label the first character
+ * must be a letter, digit or '_', the last a letter or digit, and any
+ * in between a letter, digit or '-'; a one-character label is checked
+ * only as a first character.  Returns 1 if acceptable, 0 otherwise.
+ */
+static int
+gai_srvok(const char *dn)
+{
+	int pch, ch, nch;
+
+	/*
+	 * pch, ch and nch are the previous, current and next characters.
+	 * nch is read only after checking ch != '\0', so dn is never
+	 * read past its terminating NUL.
+	 */
+	for (pch = PERIOD, ch = *dn++; ch != '\0'; pch = ch, ch = nch) {
+		nch = *dn++;
+		if (periodchar(ch))
+			continue;
+		if (periodchar(pch)) {
+			if (!firstchar(ch))
+				return 0;
+		} else if (periodchar(nch) || nch == '\0') {
+			if (!lastchar(ch))
+				return 0;
+		} else if (!middlechar(ch))
+			return 0;
+	}
+	return 1;
+}
+
+/*
+ * Return a pointer to the port field of ai's socket address, so that
+ * callers can both read and overwrite the port.  For an address family
+ * without a port field, a pointer to a zeroed static dummy is returned;
+ * that dummy is shared by all callers.
+ */
+static in_port_t *
+getport(struct addrinfo *ai)
+{
+	static in_port_t p;
+
+	switch (ai->ai_family) {
+	case AF_INET:
+		return &((struct sockaddr_in *)(void *)ai->ai_addr)->sin_port;
+#ifdef INET6
+	case AF_INET6:
+		return &((struct sockaddr_in6 *)(void *)ai->ai_addr)->sin6_port;
+#endif
+	default:
+		p = 0;
+		/* XXX: abort()? */
+		return &p;
+	}
+}
+
 static int
 str2number(const char *p)
 {
@@ -589,7 +654,7 @@ explore_fqdn(const struct addrinfo *pai,
 		return 0;
 
 	switch (nsdispatch(&result, dtab, NSDB_HOSTS, "getaddrinfo",
-			default_dns_files, hostname, pai)) {
+	    default_dns_files, hostname, pai, servname)) {
 	case NS_TRYAGAIN:
 		error = EAI_AGAIN;
 		goto free;
@@ -602,6 +667,9 @@ explore_fqdn(const struct addrinfo *pai,
 	case NS_SUCCESS:
 		error = 0;
 		for (cur = result; cur; cur = cur->ai_next) {
+			/* Check for already filled port. */
+			if (*getport(cur))
+				continue;
 			GET_PORT(cur, servname, svd);
 			/* canonname should be filled already */
 		}
@@ -990,21 +1058,8 @@ get_port(const struct addrinfo *ai, cons
 		port = sp->s_port;
 	}
 
-	if (!matchonly) {
-		switch (ai->ai_family) {
-		case AF_INET:
-			((struct sockaddr_in *)(void *)
-			    ai->ai_addr)->sin_port = port;
-			break;
-#ifdef INET6
-		case AF_INET6:
-			((struct sockaddr_in6 *)(void *)
-			    ai->ai_addr)->sin6_port = port;
-			break;
-#endif
-		}
-	}
-
+	if (!matchonly)
+		*getport(__UNCONST(ai)) = port;
 	return 0;
 }
 
@@ -1107,7 +1162,7 @@ getanswer(const querybuf *answer, int an
     const struct addrinfo *pai)
 {
 	struct addrinfo sentinel, *cur;
-	struct addrinfo ai;
+	struct addrinfo ai, *aip;
 	const struct afd *afd;
 	char *canonname;
 	const HEADER *hp;
@@ -1120,6 +1175,8 @@ getanswer(const querybuf *answer, int an
 	char tbuf[MAXDNAME];
 	int (*name_ok) (const char *);
 	char hostbuf[8*1024];
+	int port, pri, weight;
+	struct srvinfo *srvlist, *srv, *csrv;
 
 	_DIAGASSERT(answer != NULL);
 	_DIAGASSERT(qname != NULL);
@@ -1136,6 +1193,9 @@ getanswer(const querybuf *answer, int an
 	case T_ANY:	/*use T_ANY only for T_A/T_AAAA lookup*/
 		name_ok = res_hnok;
 		break;
+	case T_SRV:
+		name_ok = gai_srvok;
+		break;
 	default:
 		return NULL;	/* XXX should be abort(); */
 	}
@@ -1175,6 +1235,7 @@ getanswer(const querybuf *answer, int an
 	}
 	haveanswer = 0;
 	had_error = 0;
+	srvlist = NULL;
 	while (ancount-- > 0 && cp < eom && !had_error) {
 		n = dn_expand(answer->buf, eom, cp, bp, (int)(ep - bp));
 		if ((n < 0) || !(*name_ok)(bp)) {
@@ -1277,17 +1338,116 @@ getanswer(const querybuf *answer, int an
 				cur = cur->ai_next;
 			cp += n;
 			break;
+		case T_SRV:
+			/* Add to SRV list. Insertion sort on priority. */
+			pri = _getshort(cp);
+			cp += INT16SZ;
+			weight = _getshort(cp);
+			cp += INT16SZ;
+			port = _getshort(cp);
+			cp += INT16SZ;
+			n = dn_expand(answer->buf, eom, cp, tbuf,
+			    (int)sizeof(tbuf));
+			if ((n < 0) || !res_hnok(tbuf)) {
+				had_error++;
+				continue;
+			}
+			cp += n;
+			if (strlen(tbuf) + 1 >= MAXDNAME) {
+				had_error++;
+				continue;
+			}
+			srv = malloc(sizeof(*srv));
+			if (!srv) {
+				had_error++;
+				continue;
+			}
+			strlcpy(srv->name, tbuf, sizeof(srv->name));
+			srv->pri = pri;
+			srv->weight = weight;
+			srv->port = port;
+			/* Weight 0 is sorted before other weights. */
+			if (!srvlist
+			    || srv->pri < srvlist->pri
+			    || (srv->pri == srvlist->pri &&
+			    (!srv->weight || srvlist->weight))) {
+				srv->next = srvlist;
+				srvlist = srv;
+			} else {
+				for (csrv = srvlist;
+				    csrv->next && csrv->next->pri <= srv->pri;
+				    csrv = csrv->next) {
+					if (csrv->next->pri == srv->pri
+					    && (!srv->weight ||
+					    csrv->next->weight))
+						break;
+				}
+				srv->next = csrv->next;
+				csrv->next = srv;
+			}
+			continue; /* Don't add to haveanswer yet. */
 		default:
 			abort();
 		}
 		if (!had_error)
 			haveanswer++;
 	}
+
+	if (srvlist) {
+		res_state res;
+		/*
+		 * Check for explicit rejection.
+		 */
+		if (!srvlist->next && !srvlist->name[0]) {
+			free(srvlist);
+			h_errno = HOST_NOT_FOUND;
+			return NULL;
+		}
+		res = __res_get_state();
+		if (res == NULL) {
+			h_errno = NETDB_INTERNAL;
+			return NULL;
+		}
+
+		while (srvlist) {
+			struct res_target q, q2;
+
+			srv = srvlist;
+			srvlist = srvlist->next;
+
+			/*
+			 * Since res_* doesn't give the additional
+			 * section, we always look up.
+			 */
+			memset(&q, 0, sizeof(q));
+			memset(&q2, 0, sizeof(q2));
+
+			q.name = srv->name;
+			q.qclass = C_IN;
+			q.qtype = T_AAAA;
+			q.next = &q2;
+			q2.name = srv->name;
+			q2.qclass = C_IN;
+			q2.qtype = T_A;
+
+			aip = _dns_query(&q, pai, res, 0);
+
+			if (aip != NULL) {
+				cur->ai_next = aip;
+				while (cur && cur->ai_next) {
+					cur = cur->ai_next;
+					*getport(cur) = htons(srv->port);
+					haveanswer++;
+				}
+			}
+			free(srv);
+		}
+		__res_put_state(res);
+	}
 	if (haveanswer) {
-		if (!canonname)
-			(void)get_canonname(pai, sentinel.ai_next, qname);
-		else
-			(void)get_canonname(pai, sentinel.ai_next, canonname);
+		if (!sentinel.ai_next->ai_canonname)
+		       (void)get_canonname(pai, sentinel.ai_next,
+			   canonname ? canonname : qname);
 		h_errno = NETDB_SUCCESS;
 		return sentinel.ai_next;
 	}
@@ -1327,117 +1487,251 @@ aisort(struct addrinfo *s, res_state res
 	s->ai_next = head.ai_next;
 }
 
+/*
+ * Resolve q (and q->next, if any; at most two targets are handled) with
+ * res_searchN() if dosearch is non-zero, or res_queryN() otherwise.
+ * Returns the getanswer() results for q followed by those for q->next,
+ * or NULL; h_errno is set to NETDB_INTERNAL if buffers can't be allocated.
+ */
+static struct addrinfo *
+_dns_query(struct res_target *q, const struct addrinfo *pai,
+    res_state res, int dosearch)
+{
+	struct res_target *q2 = q->next;
+	querybuf *buf, *buf2;
+	struct addrinfo sentinel, *cur, *ai;
+
+#ifdef DNS_DEBUG
+	struct res_target *iter;
+	for (iter = q; iter; iter = iter->next)
+		printf("Query type %d for %s\n", iter->qtype, iter->name);
+#endif
+
+	buf = malloc(sizeof(*buf));
+	if (buf == NULL) {
+		h_errno = NETDB_INTERNAL;
+		return NULL;
+	}
+	buf2 = malloc(sizeof(*buf2));
+	if (buf2 == NULL) {
+		free(buf);
+		h_errno = NETDB_INTERNAL;
+		return NULL;
+	}
+
+	/* Empty result list; the error path below returns it as NULL. */
+	memset(&sentinel, 0, sizeof(sentinel));
+	cur = &sentinel;
+
+	q->answer = buf->buf;
+	q->anslen = sizeof(buf->buf);
+	if (q2) {
+		q2->answer = buf2->buf;
+		q2->anslen = sizeof(buf2->buf);
+	}
+
+	if ((dosearch ? res_searchN(q->name, q, res) :
+	    res_queryN(q->name, q, res)) < 0)
+		goto out;
+
+	ai = getanswer(buf, q->n, q->name, q->qtype, pai);
+	if (ai) {
+		cur->ai_next = ai;
+		while (cur && cur->ai_next)
+			cur = cur->ai_next;
+	}
+	if (q2) {
+		ai = getanswer(buf2, q2->n, q2->name, q2->qtype, pai);
+		if (ai)
+			cur->ai_next = ai;
+	}
+out:
+	free(buf);
+	free(buf2);
+	return sentinel.ai_next;
+}
+
 /*ARGSUSED*/
-static int
-_dns_getaddrinfo(void *rv, void *cb_data, va_list ap)
+/*
+ * Resolve name via DNS SRV records (RFC 2782) for servname, trying each
+ * protocol ("tcp", "udp") not excluded by pai->ai_socktype.  Returns the
+ * concatenated results, sorted if res->nsort is set, or NULL.
+ */
+static struct addrinfo *
+_dns_srv_lookup(const char *name, const char *servname,
+    const struct addrinfo *pai)
 {
-	struct addrinfo *ai;
-	querybuf *buf, *buf2;
-	const char *name;
-	const struct addrinfo *pai;
-	struct addrinfo sentinel, *cur;
-	struct res_target q, q2;
+	static const char * const srvprotos[] = { "tcp", "udp" };
+	static const int srvnottype[] = { SOCK_DGRAM, SOCK_STREAM };
+	static const int nsrvprotos = 2;
+	struct addrinfo sentinel, *cur, *ai;
+	struct servent *serv, sv;
+	struct servent_data svd;
+	struct res_target q;
 	res_state res;
+	char *tname;
+	int i;
 
-	name = va_arg(ap, char *);
-	pai = va_arg(ap, const struct addrinfo *);
+	res = __res_get_state();
+	if (res == NULL)
+		return NULL;
 
-	memset(&q, 0, sizeof(q));
-	memset(&q2, 0, sizeof(q2));
+	memset(&svd, 0, sizeof(svd));
 	memset(&sentinel, 0, sizeof(sentinel));
 	cur = &sentinel;
 
-	buf = malloc(sizeof(*buf));
-	if (buf == NULL) {
-		h_errno = NETDB_INTERNAL;
-		return NS_NOTFOUND;
-	}
-	buf2 = malloc(sizeof(*buf2));
-	if (buf2 == NULL) {
-		free(buf);
-		h_errno = NETDB_INTERNAL;
-		return NS_NOTFOUND;
+	/* Try each supported SRV protocol (currently TCP and UDP). */
+	for (i = 0; i < nsrvprotos; i++) {
+		/*
+		 * Check that the caller didn't specify a hint
+		 * which precludes this protocol.
+		 */
+		if (pai->ai_socktype == srvnottype[i])
+			continue;
+		/*
+		 * Look servname up for this protocol to get its
+		 * official name; skip protocols for which the
+		 * services database has no entry.
+		 */
+		serv = getservbyname_r(servname, srvprotos[i], &sv, &svd);
+		if (serv == NULL)
+			continue;
+
+		/* Construct the service DNS name "_service._proto.name". */
+		if (asprintf(&tname, "_%s._%s.%s", serv->s_name, serv->s_proto,
+		    name) < 0)
+			continue;
+
+		memset(&q, 0, sizeof(q));
+		q.name = tname;
+		q.qclass = C_IN;
+		q.qtype = T_SRV;
+
+		/* Do the SRV query and append whatever it returns. */
+		ai = _dns_query(&q, pai, res, 1);
+		if (ai) {
+			cur->ai_next = ai;
+			while (cur && cur->ai_next)
+				cur = cur->ai_next;
+		}
+		free(tname);
 	}
 
+	endservent_r(&svd);	/* release getservbyname_r() state */
+
+	if (res->nsort)
+		aisort(&sentinel, res);
+
+	__res_put_state(res);
+
+	return sentinel.ai_next;
+}
+
+/* Look up the A and/or AAAA records of name selected by pai->ai_family. */
+static struct addrinfo *
+_dns_host_lookup(const char *name, const struct addrinfo *pai)
+{
+	struct res_target q, q2;
+	struct addrinfo sentinel, *ai;
+	res_state res;
+
+	memset(&q, 0, sizeof(q));
+	memset(&q2, 0, sizeof(q2));
+
 	switch (pai->ai_family) {
 	case AF_UNSPEC:
 		/* prefer IPv6 */
 		q.name = name;
 		q.qclass = C_IN;
 		q.qtype = T_AAAA;
-		q.answer = buf->buf;
-		q.anslen = sizeof(buf->buf);
 		q.next = &q2;
 		q2.name = name;
 		q2.qclass = C_IN;
 		q2.qtype = T_A;
-		q2.answer = buf2->buf;
-		q2.anslen = sizeof(buf2->buf);
 		break;
 	case AF_INET:
 		q.name = name;
 		q.qclass = C_IN;
 		q.qtype = T_A;
-		q.answer = buf->buf;
-		q.anslen = sizeof(buf->buf);
 		break;
 	case AF_INET6:
 		q.name = name;
 		q.qclass = C_IN;
 		q.qtype = T_AAAA;
-		q.answer = buf->buf;
-		q.anslen = sizeof(buf->buf);
 		break;
 	default:
-		free(buf);
-		free(buf2);
-		return NS_UNAVAIL;
+		h_errno = NETDB_INTERNAL;
+		return NULL;
 	}
 
-	res = __res_get_state();
-	if (res == NULL) {
-		free(buf);
-		free(buf2);
-		return NS_NOTFOUND;
-	}
+	/* Taken only now, so the unsupported-family exit cannot leak it. */
+	res = __res_get_state();
+	if (res == NULL)
+		return NULL;
+	ai = _dns_query(&q, pai, res, 1);
 
-	if (res_searchN(name, &q, res) < 0) {
-		__res_put_state(res);
-		free(buf);
-		free(buf2);
-		return NS_NOTFOUND;
-	}
-	ai = getanswer(buf, q.n, q.name, q.qtype, pai);
-	if (ai) {
-		cur->ai_next = ai;
-		while (cur && cur->ai_next)
-			cur = cur->ai_next;
-	}
-	if (q.next) {
-		ai = getanswer(buf2, q2.n, q2.name, q2.qtype, pai);
-		if (ai)
-			cur->ai_next = ai;
-	}
-	free(buf);
-	free(buf2);
-	if (sentinel.ai_next == NULL) {
-		__res_put_state(res);
-		switch (h_errno) {
-		case HOST_NOT_FOUND:
-			return NS_NOTFOUND;
-		case TRY_AGAIN:
-			return NS_TRYAGAIN;
-		default:
-			return NS_UNAVAIL;
-		}
-	}
+	memset(&sentinel, 0, sizeof(sentinel));
+	sentinel.ai_next = ai;
 
-	if (res->nsort)
+	if (ai != NULL && res->nsort)
 		aisort(&sentinel, res);
 
 	__res_put_state(res);
 
-	*((struct addrinfo **)rv) = sentinel.ai_next;
+	return sentinel.ai_next;
+}
+
+/*ARGSUSED*/
+static int
+_dns_getaddrinfo(void *rv, void *cb_data, va_list ap)
+{
+	struct addrinfo *ai = NULL;	/* result list, stored into *rv */
+	const char *name, *servname;
+	const struct addrinfo *pai;
+	/* va_list holds: hostname, hints, servname (which may be NULL). */
+	name = va_arg(ap, char *);
+	pai = va_arg(ap, const struct addrinfo *);
+	servname = va_arg(ap, char *);
+
+	/*
+	 * Try an SRV lookup first, but only for a service given by name.
+	 */
+	if (servname
+#ifdef AI_SRV	/* without AI_SRV defined, SRV lookup is not flag-gated */
+	    && (pai->ai_flags & AI_SRV)
+#endif
+	    && !(pai->ai_flags & AI_NUMERICSERV)
+	    && str2number(servname) == -1) {
+		/* a NULL result falls through to the host lookup below */
+#ifdef DNS_DEBUG
+		printf("%s: try SRV lookup\n", __func__);
+#endif
+		ai = _dns_srv_lookup(name, servname, pai);
+	}
+
+	/*
+	 * No SRV results (or none attempted): look up the host name itself.
+	 */
+	if (ai == NULL) {
+
+#ifdef DNS_DEBUG
+		printf("%s: try HOST lookup\n", __func__);
+#endif
+		ai = _dns_host_lookup(name, pai);
+		/* on failure, map the resolver's h_errno to an NS_* status */
+		if (ai == NULL) {
+			switch (h_errno) {
+			case HOST_NOT_FOUND:
+				return NS_NOTFOUND;
+			case TRY_AGAIN:
+				return NS_TRYAGAIN;
+			default:
+				return NS_UNAVAIL;
+			}
+		}
+	}
+
+	*((struct addrinfo **)rv) = ai;	/* hand the list to the caller */
 	return NS_SUCCESS;
 }
 
@@ -1996,3 +2290,40 @@ res_querydomainN(const char *name, const
 	}
 	return res_queryN(longname, target, res);
 }
+
+#ifdef TEST
+/*
+ * Test driver: resolve each command-line argument with getaddrinfo()
+ * and print every addrinfo entry returned for it.
+ */
+int
+main(int argc, char *argv[])
+{
+	struct addrinfo *ai, *sai;
+	int i, e;
+	char buf[1024];
+
+	for (i = 1; i < argc; i++) {
+		if ((e = getaddrinfo(argv[i], NULL, NULL, &sai)) != 0) {
+			/* sai is only valid on success: skip this name. */
+			warnx("%s: %s", argv[i], gai_strerror(e));
+			continue;
+		}
+		for (ai = sai; ai; ai = ai->ai_next) {
+			sockaddr_snprintf(buf, sizeof(buf), "%a", ai->ai_addr);
+			printf("flags=0x%x family=%d socktype=%d protocol=%d "
+			    "addrlen=%zu addr=%s canonname=%s next=%p\n",
+			    ai->ai_flags,
+			    ai->ai_family,
+			    ai->ai_socktype,
+			    ai->ai_protocol,
+			    (size_t)ai->ai_addrlen,
+			    buf,
+			    ai->ai_canonname ? ai->ai_canonname : "(null)",
+			    (void *)ai->ai_next);
+		}
+		freeaddrinfo(sai);
+	}
+	return 0;
+}
+#endif

Reply via email to