Hi everyone,

I've put together a patch for 6.0-stable that adds domain name
matching to rebound(8), so that queries can be forwarded to different
resolvers depending on the queried name.  The patch is still quite
rough.

The config is as follows:

        match "local." 10.0.0.53
        match "." 8.8.8.8

Queries for foo.local. are sent to 10.0.0.53; all other queries go to
8.8.8.8.  Rules are checked in order and the first matching rule wins.

General drawbacks:

- rebound now has to parse the question name in DNS requests.  I tried
  to keep the parsing code as small as possible to limit the attack
  surface.

Drawbacks in current implementation:

- No caching for DNS requests over TCP.  I am planning to implement
  this via a unified cache that works for both UDP and TCP.
- No non-blocking connect(2) for TCP yet.  The original code handled
  that, but I reworked it because I wanted to get the basic forwarding
  working first.

What do you think?

===================================================================
RCS file: /cvs/src/usr.sbin/rebound/rebound.c,v
retrieving revision 1.65
diff -u -p -r1.65 rebound.c
--- rebound.c   2 Jul 2016 17:09:09 -0000       1.65
+++ rebound.c   16 Sep 2016 12:29:39 -0000
@@ -37,6 +37,8 @@
 #include <getopt.h>
 #include <stdarg.h>
 
+/* Number of elements in array x; x must be a true array, not a pointer. */
+#define LEN(x) (sizeof(x) / sizeof((x)[0]))
+
 uint16_t randomid(void);
 
 static struct timespec now;
@@ -100,6 +102,13 @@ struct request {
 };
 static TAILQ_HEAD(, request) reqfifo;
 
+/*
+ * A "match" rule from the config file.  Queries whose name matches pat
+ * (stored without the surrounding quotes) are forwarded to the resolver
+ * address in to.  Rules are kept on the matches list in config file order
+ * and the first matching rule wins.
+ */
+struct match {
+       TAILQ_ENTRY(match) entry;       /* link on the matches list */
+       struct sockaddr_storage to;     /* upstream resolver, port 53 */
+       char pat[256];                  /* domain name pattern */
+};
+static TAILQ_HEAD(, match) matches;
+
 static int conncount;
 static int connmax;
 static uint64_t conntotal;
@@ -215,10 +224,94 @@ servfail(int ud, uint16_t id, struct soc
        sendto(ud, &pkt, sizeof(pkt), 0, fromaddr, fromlen);
 }
 
+/*
+ * Read exactly n bytes from fd into buf, retrying after short reads and
+ * EINTR.  Returns the number of bytes read (n), or -1 on a read error or
+ * if EOF is reached before n bytes arrived.
+ */
+static ssize_t
+readn(int fd, void *buf, size_t n)
+{
+       uint8_t *p = buf;
+       size_t total = 0;
+       ssize_t r;
+
+       while (total < n) {
+               r = read(fd, p + total, n - total);
+               if (r == -1 && errno == EINTR)
+                       continue;
+               /* 0 is a premature EOF: the peer closed mid-message */
+               if (r <= 0)
+                       return -1;
+               total += r;
+       }
+       return total;
+}
+
+/*
+ * Write exactly n bytes from buf to fd, retrying after short writes and
+ * EINTR.  Returns the number of bytes written (n), or -1 on error.
+ */
+static ssize_t
+writen(int fd, const void *buf, size_t n)
+{
+       const uint8_t *p = buf;
+       size_t total = 0;
+       ssize_t r;
+
+       while (total < n) {
+               r = write(fd, p + total, n - total);
+               if (r == -1 && errno == EINTR)
+                       continue;
+               if (r <= 0)
+                       return -1;
+               total += r;
+       }
+       return total;
+}
+
+/*
+ * Extract the name of the first question in the DNS message
+ * buf[0..buflen) and store it in host as a NUL-terminated, dot-separated
+ * string with a trailing dot ("www.example.com.").  The root name is
+ * stored as ".".  Label bytes are copied verbatim, without escaping.
+ * Returns 0 on success, or -1 if the message is truncated, a label is
+ * longer than 63 octets (this also rejects compression pointers), or the
+ * name does not fit into host.
+ */
+static int
+parsedomain(const uint8_t *buf, size_t buflen, char *host, size_t hostlen)
+{
+       size_t off = sizeof(struct dnspacket);
+       size_t hoff = 0;
+       uint8_t len;
+
+       /* work with offsets so no pointer ever leaves buf or host */
+       if (hostlen < 2 || buflen <= off)
+               return -1;
+       for (;;) {
+               len = buf[off++];
+               if (len == 0)
+                       break;
+               /* 0x40-0xff are compression pointers or reserved types */
+               if (len > 63)
+                       return -1;
+               /* need the label bytes plus the next length byte */
+               if (len >= buflen - off)
+                       return -1;
+               /* need the label, a '.', and room for the final NUL */
+               if ((size_t)len + 1 >= hostlen - hoff)
+                       return -1;
+               memcpy(&host[hoff], &buf[off], len);
+               off += len;
+               hoff += len;
+               host[hoff++] = '.';
+       }
+       /* root name: represent it as "." so a "." rule can match it */
+       if (hoff == 0)
+               host[hoff++] = '.';
+       host[hoff] = '\0';
+       return 0;
+}
+
+/* Compare n bytes of a and b, ignoring ASCII case as DNS does (RFC 4343). */
+static int
+namecaseeq(const char *a, const char *b, size_t n)
+{
+       unsigned char ca, cb;
+       size_t i;
+
+       for (i = 0; i < n; i++) {
+               ca = a[i];
+               cb = b[i];
+               if (ca >= 'A' && ca <= 'Z')
+                       ca += 'a' - 'A';
+               if (cb >= 'A' && cb <= 'Z')
+                       cb += 'a' - 'A';
+               if (ca != cb)
+                       return 0;
+       }
+       return 1;
+}
+
+/*
+ * Return 1 if the name host (length hlen) falls under pat: pat must be a
+ * suffix of host that begins on a label boundary, so "local." matches
+ * "local." and "foo.local." but not "notlocal.".  A pattern that itself
+ * starts with '.' (including ".") only needs to be a suffix.
+ */
+static int
+namematch(const char *host, size_t hlen, const char *pat)
+{
+       size_t plen = strlen(pat);
+       const char *suffix;
+
+       if (plen > hlen)
+               return 0;
+       suffix = host + (hlen - plen);
+       if (suffix != host && pat[0] != '.' && suffix[-1] != '.')
+               return 0;
+       return namecaseeq(suffix, pat, plen);
+}
+
+/*
+ * Choose the upstream resolver for the DNS query in buf[0..buflen).
+ * Rules on the matches list are tried in order; the first one whose
+ * pattern matches the query name wins and its address is copied to *to.
+ * Returns 0 on a match, -1 if nothing matches or the name can't be parsed.
+ */
+static int
+matchreq(uint8_t *buf, size_t buflen, struct sockaddr_storage *to)
+{
+       /* a valid DNS name never exceeds 255 octets */
+       char host[256];
+       struct match *match;
+       size_t hlen;
+
+       /* XXX: check flags/qdcount? */
+       if (parsedomain(buf, buflen, host, sizeof(host)) == -1)
+               return -1;
+       hlen = strlen(host);
+       TAILQ_FOREACH(match, &matches, entry) {
+               if (!namematch(host, hlen, match->pat))
+                       continue;
+               memcpy(to, &match->to, sizeof(*to));
+               logmsg(LOG_DEBUG, "matched domain %s with %s",
+                   host, match->pat);
+               /* first match wins */
+               return 0;
+       }
+       return -1;
+}
+
 static struct request *
-newrequest(int ud, struct sockaddr *remoteaddr)
+newrequest(int ud)
 {
-       struct sockaddr from;
+       struct sockaddr_storage remoteaddr;
+       struct sockaddr from, *to;
        socklen_t fromlen;
        struct request *req;
        uint8_t buf[65536];
@@ -271,13 +364,17 @@ newrequest(int ud, struct sockaddr *remo
        }
        req->cacheent = hit;
 
-       req->s = socket(remoteaddr->sa_family, SOCK_DGRAM, 0);
+       if (matchreq(buf, r, &remoteaddr) == -1)
+               goto fail;
+       to = (struct sockaddr *)&remoteaddr;
+
+       req->s = socket(to->sa_family, SOCK_DGRAM, 0);
        if (req->s == -1)
                goto fail;
 
        TAILQ_INSERT_TAIL(&reqfifo, req, fifo);
 
-       if (connect(req->s, remoteaddr, remoteaddr->sa_len) == -1) {
+       if (connect(req->s, to, to->sa_len) == -1) {
                logmsg(LOG_NOTICE, "failed to connect (%d)", errno);
                if (errno == EADDRNOTAVAIL)
                        servfail(ud, req->clientid, &from, fromlen);
@@ -335,36 +432,18 @@ sendreply(int ud, struct request *req)
 }
 
 static struct request *
-tcpphasetwo(struct request *req)
-{
-       int error;
-       socklen_t len = sizeof(error);
-
-       req->tcp = 2;
-
-       if (getsockopt(req->s, SOL_SOCKET, SO_ERROR, &error, &len) == -1 ||
-           error != 0)
-               goto fail;
-       if (setsockopt(req->client, SOL_SOCKET, SO_SPLICE, &req->s,
-           sizeof(req->s)) == -1)
-               goto fail;
-       if (setsockopt(req->s, SOL_SOCKET, SO_SPLICE, &req->client,
-           sizeof(req->client)) == -1)
-               goto fail;
-
-       return req;
-
-fail:
-       freerequest(req);
-       return NULL;
-}
-
-static struct request *
-newtcprequest(int ld, struct sockaddr *remoteaddr)
+newtcprequest(int ld)
 {
+       struct sockaddr_storage remoteaddr;
+       struct sockaddr *to;
        struct request *req;
+       uint8_t buf[65536];
+       struct dnspacket *dnsreq;
+       uint16_t reqsize;
        int client;
 
+       dnsreq = (struct dnspacket *)&buf[2];
+
        client = accept(ld, NULL, 0);
        if (client == -1) {
                if (errno == ENFILE || errno == EMFILE)
@@ -372,6 +451,24 @@ newtcprequest(int ld, struct sockaddr *r
                return NULL;
        }
 
+       if (readn(client, &reqsize, sizeof(reqsize)) == -1) {
+               close(client);
+               return NULL;
+       }
+       if (reqsize > sizeof(buf) - 2) {
+               close(client);
+               return NULL;
+       }
+       memcpy(buf, &reqsize, sizeof(reqsize));
+
+       reqsize = ntohs(reqsize);
+       if (readn(client, &buf[2], reqsize) == -1) {
+               close(client);
+               return NULL;
+       }
+
+       /* XXX: unified cache handling for tcp/udp requests */
+
        if (!(req = calloc(1, sizeof(*req)))) {
                close(client);
                return NULL;
@@ -383,18 +480,31 @@ newtcprequest(int ld, struct sockaddr *r
        req->ts.tv_sec += 30;
        req->tcp = 1;
        req->client = client;
+       req->s = -1;
+
+       req->clientid = dnsreq->id;
+       req->reqid = randomid();
+       dnsreq->id = req->reqid;
 
-       req->s = socket(remoteaddr->sa_family, SOCK_STREAM | SOCK_NONBLOCK, 0);
+       if (matchreq(&buf[2], reqsize, &remoteaddr) == -1)
+               goto fail;
+       to = (struct sockaddr *)&remoteaddr;
+
+       req->s = socket(to->sa_family, SOCK_STREAM, 0);
        if (req->s == -1)
                goto fail;
 
        TAILQ_INSERT_TAIL(&reqfifo, req, fifo);
 
-       if (connect(req->s, remoteaddr, remoteaddr->sa_len) == -1) {
-               if (errno != EINPROGRESS)
-                       goto fail;
-       } else {
-               return tcpphasetwo(req);
+       /* XXX: should really use non-blocking connect */
+       if (connect(req->s, to, to->sa_len) == -1) {
+               logmsg(LOG_NOTICE, "failed to connect (%d)", errno);
+               goto fail;
+       }
+
+       if (writen(req->s, buf, reqsize + 2) == -1) {
+               logmsg(LOG_NOTICE, "failed to write (%d)", errno);
+               goto fail;
        }
 
        return req;
@@ -404,43 +514,133 @@ fail:
        return NULL;
 }
 
+/*
+ * Relay one length-prefixed DNS response from the upstream TCP connection
+ * req->s to the client connection req->client, restoring the client's
+ * original message ID.  Responses that are truncated, shorter than a DNS
+ * header, or carry an ID other than req->reqid are dropped.
+ *
+ * XXX: both sockets are used in blocking mode here, so a slow upstream or
+ * client stalls the event loop.  Only a single response message is relayed,
+ * so multi-message TCP responses (e.g. zone transfers) are cut short.
+ */
+static void
+sendtcpreply(struct request *req)
+{
+       uint8_t buf[65536];
+       struct dnspacket *resp;
+       uint16_t prefix, resplen;
+
+       resp = (struct dnspacket *)&buf[2];
+
+       if (readn(req->s, &prefix, sizeof(prefix)) == -1)
+               return;
+       resplen = ntohs(prefix);
+       /* the ID is read from the header below, so demand a whole header */
+       if (resplen < sizeof(struct dnspacket) || resplen > sizeof(buf) - 2)
+               return;
+       memcpy(buf, &prefix, sizeof(prefix));
+
+       if (readn(req->s, &buf[2], resplen) == -1)
+               return;
+       if (resp->id != req->reqid)
+               return;
+       resp->id = req->clientid;
+
+       if (writen(req->client, buf, resplen + 2) == -1)
+               logmsg(LOG_NOTICE, "failed to write (%d)", errno);
+
+       /* XXX: cache handling */
+}
+
+/* Unlink and free every rule on the matches list, leaving it empty. */
+static void
+free_matches(void)
+{
+       struct match *m;
+
+       while ((m = TAILQ_FIRST(&matches)) != NULL) {
+               TAILQ_REMOVE(&matches, m, entry);
+               free(m);
+       }
+}
+
 static int
-readconfig(FILE *conf, struct sockaddr_storage *remoteaddr)
+readconfig(FILE *conf)
 {
+#define KEYWORDIDX     0
+#define PATTERNIDX     1
+#define NSIDX          2
+#define NTOKENS                3
+       char *tokens[NTOKENS], *p, *last;
        char buf[1024];
-       struct sockaddr_in *sin = (struct sockaddr_in *)remoteaddr;
-       struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)remoteaddr;
+       struct sockaddr_in *sin;
+       struct sockaddr_in6 *sin6;
+       struct match *match;
+       int i;
+
+       free_matches(); /* for SIGHUP */
+       while (fgets(buf, sizeof(buf), conf) != NULL) {
+               buf[strcspn(buf, "\n")] = '\0';
+
+               /* tokenize line */
+               for (i = 0, p = strtok_r(buf, " \t", &last);
+                    p != NULL;
+                    p = strtok_r(NULL, " \t", &last))
+                       if (i < LEN(tokens))
+                               tokens[i++] = p;
+               if (i != NTOKENS)
+                       goto fail;
 
-       if (fgets(buf, sizeof(buf), conf) == NULL)
-               return -1;
-       buf[strcspn(buf, "\n")] = '\0';
+               /* only recognize the match keyword so far */
+               if (strcmp(tokens[KEYWORDIDX], "match") != 0)
+                       goto fail;
 
-       memset(remoteaddr, 0, sizeof(*remoteaddr));
-       if (inet_pton(AF_INET, buf, &sin->sin_addr) == 1) {
-               sin->sin_len = sizeof(*sin);
-               sin->sin_family = AF_INET;
-               sin->sin_port = htons(53);
-               return AF_INET;
-       } else if (inet_pton(AF_INET6, buf, &sin6->sin6_addr) == 1) {
-               sin6->sin6_len = sizeof(*sin6);
-               sin6->sin6_family = AF_INET6;
-               sin6->sin6_port = htons(53);
-               return AF_INET6;
-       } else {
-               return -1;
+               match = malloc(sizeof(*match));
+               if (match == NULL)
+                       goto fail;
+
+               /* extract pattern */
+               for (i = 0, p = tokens[PATTERNIDX]; *p != '\0'; p++) {
+                       if (*p == '"')
+                               continue;
+                       if (i < LEN(match->pat) - 1)
+                               match->pat[i++] = *p;
+               }
+               if (i == 0) {
+                       /* empty pattern? bail */
+                       free(match);
+                       goto fail;
+               }
+               match->pat[i] = '\0';
+
+               memset(&match->to, 0, sizeof(match->to));
+               sin = (struct sockaddr_in *)&match->to;
+               sin6 = (struct sockaddr_in6 *)&match->to;
+               if (inet_pton(AF_INET, tokens[NSIDX], &sin->sin_addr) == 1) {
+                       sin->sin_len = sizeof(*sin);
+                       sin->sin_family = AF_INET;
+                       sin->sin_port = htons(53);
+               } else if (inet_pton(AF_INET6, tokens[NSIDX], &sin6->sin6_addr) 
== 1) {
+                       sin6->sin6_len = sizeof(*sin6);
+                       sin6->sin6_family = AF_INET6;
+                       sin6->sin6_port = htons(53);
+               } else {
+                       free(match);
+                       goto fail;
+               }
+
+               TAILQ_INSERT_TAIL(&matches, match, entry);
        }
+
+       /* we need at least one match rule */
+       if (TAILQ_EMPTY(&matches))
+               goto fail;
+
+       return 0;
+fail:
+       free_matches();
+       return -1;
 }
 
 static int
 launch(FILE *conf, int ud, int ld, int kq)
 {
-       struct sockaddr_storage remoteaddr;
        struct kevent ch[2], kev[4];
        struct timespec ts, *timeout = NULL;
        struct request *req;
        struct dnscache *ent;
        struct passwd *pwd;
-       int i, r, af;
+       int i, r;
        pid_t parent, child;
 
        parent = getpid();
@@ -476,9 +676,9 @@ launch(FILE *conf, int ud, int ld, int k
        if (pledge("stdio inet", NULL) == -1)
                logerr("pledge failed");
 
-       af = readconfig(conf, &remoteaddr);
+       r = readconfig(conf);
        fclose(conf);
-       if (af == -1)
+       if (r == -1)
                logerr("parse error in config file");
 
        EV_SET(&kev[0], ud, EVFILT_READ, EV_ADD, 0, 0, NULL);
@@ -517,37 +717,26 @@ launch(FILE *conf, int ud, int ld, int k
                        } else if (kev[i].filter == EVFILT_PROC) {
                                logmsg(LOG_INFO, "parent died");
                                exit(0);
-                       } else if (kev[i].filter == EVFILT_WRITE) {
-                               req = kev[i].udata;
-                               req = tcpphasetwo(req);
-                               if (req) {
-                                       EV_SET(&ch[0], req->s, EVFILT_WRITE,
-                                           EV_DELETE, 0, 0, NULL);
-                                       EV_SET(&ch[1], req->s, EVFILT_READ,
-                                           EV_ADD, 0, 0, req);
-                                       kevent(kq, ch, 2, NULL, 0, NULL);
-                               }
                        } else if (kev[i].filter != EVFILT_READ) {
                                logerr("don't know what happened");
                        } else if (kev[i].ident == ud) {
-                               if ((req = newrequest(ud,
-                                   (struct sockaddr *)&remoteaddr))) {
+                               if ((req = newrequest(ud))) {
                                        EV_SET(&ch[0], req->s, EVFILT_READ,
                                            EV_ADD, 0, 0, req);
                                        kevent(kq, ch, 1, NULL, 0, NULL);
                                }
                        } else if (kev[i].ident == ld) {
-                               if ((req = newtcprequest(ld,
-                                   (struct sockaddr *)&remoteaddr))) {
-                                       EV_SET(&ch[0], req->s,
-                                           req->tcp == 1 ? EVFILT_WRITE :
-                                           EVFILT_READ, EV_ADD, 0, 0, req);
+                               if ((req = newtcprequest(ld))) {
+                                       EV_SET(&ch[0], req->s, EVFILT_READ,
+                                           EV_ADD, 0, 0, req);
                                        kevent(kq, ch, 1, NULL, 0, NULL);
                                }
                        } else {
                                req = kev[i].udata;
                                if (req->tcp == 0)
                                        sendreply(ud, req);
+                               else
+                                       sendtcpreply(req);
                                freerequest(req);
                        }
                }
@@ -655,6 +844,7 @@ main(int argc, char **argv)
 
        TAILQ_INIT(&reqfifo);
        TAILQ_INIT(&cachefifo);
+       TAILQ_INIT(&matches);
        RB_INIT(&cachetree);
 
        memset(&bindaddr, 0, sizeof(bindaddr));

Reply via email to