On Fri, Jan 08, 2021 at 06:44:41PM +0100, Florian Obser wrote:
> This is on top of the "unwind(8): respect DO flag" diff I just sent to
> tech.
> 
> This is a bit rough around the edges, but if you feel lucky...
> 
> Currently unwind(8) only accepts queries over udp. We can get away
> with that since we are only listening on localhost and localhost has a
> large mtu, but it's technically not correct and even over localhost we
> can't have answers larger than somewhere around 40 kbytes.

Updated diff, this is now also good to go.

Tests, OKs?

diff --git frontend.c frontend.c
index 0462f37a258..c9b30f618e5 100644
--- frontend.c
+++ frontend.c
@@ -74,6 +74,10 @@
 #define COMPRESSED_RR_SIZE     12
 #define MINIMIZE_ANSWER                1
 
+#define FD_RESERVE             5
+#define TCP_TIMEOUT            15
+#define DEFAULT_TCP_SIZE       512
+
 struct udp_ev {
        struct event             ev;
        uint8_t                  query[65536];
@@ -82,6 +86,11 @@ struct udp_ev {
        struct sockaddr_storage  from;
 } udp4ev, udp6ev;
 
+struct tcp_accept_ev {
+       struct event             ev;
+       struct event             pause;
+} tcp4ev, tcp6ev;
+
 struct pending_query {
        TAILQ_ENTRY(pending_query)       entry;
        struct sockaddr_storage          from;
@@ -91,8 +100,12 @@ struct pending_query {
        struct query_info                qinfo;
        struct msg_parse                *qmsg;
        struct edns_data                 edns;
+       struct event                     ev;            /* for tcp */
+       struct event                     resp_ev;       /* for tcp */
+       struct event                     tmo_ev;        /* for tcp */
        uint64_t                         imsg_id;
        int                              fd;
+       int                              tcp;
 };
 
 TAILQ_HEAD(, pending_query)     pending_queries;
@@ -108,6 +121,12 @@ void                        frontend_startup(void);
 void                    udp_receive(int, short, void *);
 void                    handle_query(struct pending_query *);
 void                    free_pending_query(struct pending_query *);
+void                    tcp_accept(int, short, void *);
+int                     accept_reserve(int, struct sockaddr *, socklen_t *);
+void                    accept_paused(int, short, void *);
+void                    tcp_request(int, short, void *);
+void                    tcp_response(int, short, void *);
+void                    tcp_timeout(int, short, void *);
 int                     check_query(sldns_buffer*);
 void                    noerror_answer(struct pending_query *);
 void                    chaos_answer(struct pending_query *);
@@ -133,6 +152,7 @@ struct imsgev               *iev_main;
 struct imsgev          *iev_resolver;
 struct event            ev_route;
 int                     udp4sock = -1, udp6sock = -1, routesock = -1;
+int                     tcp4sock = -1, tcp6sock = -1;
 int                     ta_fd = -1;
 
 static struct trust_anchor_head         trust_anchors, new_trust_anchors;
@@ -371,6 +391,30 @@ frontend_dispatch_main(int fd, short event, void *bula)
                            udp_receive, &udp4ev);
                        event_add(&udp4ev.ev, NULL);
                        break;
+               case IMSG_TCP4SOCK:
+                       if (tcp4sock != -1)
+                               fatalx("%s: received unexpected tcp4sock",
+                                   __func__);
+                       if ((tcp4sock = imsg.fd) == -1)
+                               fatalx("%s: expected to receive imsg "
+                                   "TCP4 fd but didn't receive any", __func__);
+                       event_set(&tcp4ev.ev, tcp4sock, EV_READ | EV_PERSIST,
+                           tcp_accept, &tcp4ev);
+                       event_add(&tcp4ev.ev, NULL);
+                       evtimer_set(&tcp4ev.pause, accept_paused, &tcp4ev);
+                       break;
+               case IMSG_TCP6SOCK:
+                       if (tcp6sock != -1)
+                               fatalx("%s: received unexpected tcp6sock",
+                                   __func__);
+                       if ((tcp6sock = imsg.fd) == -1)
+                               fatalx("%s: expected to receive imsg "
+                                   "TCP6 fd but didn't receive any", __func__);
+                       event_set(&tcp6ev.ev, tcp6sock, EV_READ | EV_PERSIST,
+                           tcp_accept, &tcp6ev);
+                       event_add(&tcp6ev.ev, NULL);
+                       evtimer_set(&tcp6ev.pause, accept_paused, &tcp6ev);
+                       break;
                case IMSG_ROUTESOCK:
                        if (routesock != -1)
                                fatalx("%s: received unexpected routesock",
@@ -577,6 +621,16 @@ free_pending_query(struct pending_query *pq)
        regional_destroy(pq->region);
        sldns_buffer_free(pq->qbuf);
        sldns_buffer_free(pq->abuf);
+       if (pq->tcp) {
+               if (event_initialized(&pq->ev))
+                       event_del(&pq->ev);
+               if (event_initialized(&pq->resp_ev))
+                       event_del(&pq->resp_ev);
+               if (event_initialized(&pq->tmo_ev))
+                       event_del(&pq->tmo_ev);
+               if (pq->fd != -1)
+                       close(pq->fd);
+       }
        free(pq);
 }
 
@@ -763,7 +817,7 @@ noerror_answer(struct pending_query *pq)
 
        sldns_buffer_clear(pq->abuf);
        if (reply_info_encode(&pq->qinfo, rinfo, pq->qmsg->id, rinfo->flags,
-           pq->abuf, 0, pq->region, UINT16_MAX, /* XXX pq->edns.udp_size, */
+           pq->abuf, 0, pq->region, pq->tcp ? UINT16_MAX : pq->edns.udp_size,
            pq->edns.bits & EDNS_DO, MINIMIZE_ANSWER) == 0)
                goto srvfail;
 
@@ -883,12 +937,30 @@ send_answer(struct pending_query *pq)
                free(str);
        }
 
-       if(sendto(pq->fd, sldns_buffer_current(pq->abuf),
-           sldns_buffer_remaining(pq->abuf), 0,
-           (struct sockaddr *)&pq->from, pq->from.ss_len) == -1)
-               log_warn("sendto");
+       if (!pq->tcp) {
+               if(sendto(pq->fd, sldns_buffer_current(pq->abuf),
+                   sldns_buffer_remaining(pq->abuf), 0,
+                   (struct sockaddr *)&pq->from, pq->from.ss_len) == -1)
+                       log_warn("sendto");
+               free_pending_query(pq);
+       } else {
+               struct sldns_buffer     *tmp;
 
-       free_pending_query(pq);
+               tmp = sldns_buffer_new(sldns_buffer_limit(pq->abuf) + 2);
+
+               if (!tmp) {
+                       free_pending_query(pq);
+                       return;
+               }
+
+               sldns_buffer_write_u16(tmp, sldns_buffer_limit(pq->abuf));
+               sldns_buffer_write(tmp, sldns_buffer_current(pq->abuf),
+                   sldns_buffer_remaining(pq->abuf));
+               sldns_buffer_flip(tmp);
+               sldns_buffer_free(pq->abuf);
+               pq->abuf = tmp;
+               event_add(&pq->resp_ev, NULL);
+       }
 }
 
 char*
@@ -1268,3 +1340,170 @@ pending_query_cnt(void)
                cnt++;
        return cnt;
 }
+
+void
+accept_paused(int fd, short events, void *arg)
+{
+       struct tcp_accept_ev    *tcpev = arg;
+       event_add(&tcpev->ev, NULL);
+}
+
+int
+accept_reserve(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
+{
+       if (getdtablecount() + FD_RESERVE >= getdtablesize()) {
+               log_debug("%s: inflight fds exceeded", __func__);
+               errno = EMFILE;
+               return -1;
+       }
+       return accept4(sockfd, addr, addrlen, SOCK_NONBLOCK | SOCK_CLOEXEC);
+}
+
+void
+tcp_accept(int fd, short events, void *arg)
+{
+       static struct timeval    timeout = {TCP_TIMEOUT, 0};
+       static struct timeval    backoff = {1, 0};
+       struct pending_query    *pq;
+       struct tcp_accept_ev    *tcpev;
+       struct sockaddr_storage  ss;
+       socklen_t                len;
+       int                      s;
+
+       tcpev = arg;
+       len = sizeof(ss);
+
+       if ((s = accept_reserve(fd, (struct sockaddr *)&ss, &len)) == -1) {
+               switch (errno) {
+               case EINTR:
+               case EWOULDBLOCK:
+               case ECONNABORTED:
+                       return;
+               case EMFILE:
+               case ENFILE:
+                       event_del(&tcpev->ev);
+                       evtimer_add(&tcpev->pause, &backoff);
+                       return;
+               default:
+                       fatal("accept");
+               }
+       }
+
+       if ((pq = calloc(1, sizeof(*pq))) == NULL) {
+               log_warn(NULL);
+               close(s);
+               return;
+       }
+
+       do {
+               arc4random_buf(&pq->imsg_id, sizeof(pq->imsg_id));
+       } while(find_pending_query(pq->imsg_id) != NULL);
+
+       TAILQ_INSERT_TAIL(&pending_queries, pq, entry);
+
+       pq->from = ss;
+       pq->fd = s;
+       pq->tcp = 1;
+       pq->qbuf = sldns_buffer_new(DEFAULT_TCP_SIZE);
+       pq->region = regional_create();
+       pq->qmsg = regional_alloc(pq->region, sizeof(*pq->qmsg));
+
+       if (!pq->qbuf || !pq->region || !pq->qmsg) {
+               free_pending_query(pq);
+               return;
+       }
+
+       memset(pq->qmsg, 0, sizeof(*pq->qmsg));
+
+       event_set(&pq->ev, s, EV_READ | EV_PERSIST, tcp_request, pq);
+       event_add(&pq->ev, NULL);
+       event_set(&pq->resp_ev, s, EV_WRITE | EV_PERSIST, tcp_response, pq);
+
+       evtimer_set(&pq->tmo_ev, tcp_timeout, pq);
+       evtimer_add(&pq->tmo_ev, &timeout);
+}
+
+void
+tcp_request(int fd, short events, void *arg)
+{
+       struct pending_query    *pq;
+       ssize_t                  n;
+
+       pq = arg;
+
+       /* qbuf first collects the 2-byte length prefix, then the query. */
+       n = read(fd, sldns_buffer_current(pq->qbuf),
+           sldns_buffer_remaining(pq->qbuf));
+
+       if (n == -1) {
+               if (errno == EINTR || errno == EAGAIN)
+                       return;         /* retry on the next EV_READ */
+               goto fail;
+       }
+       if (n == 0) {
+               log_debug("closed connection");
+               goto fail;
+       }
+
+       sldns_buffer_skip(pq->qbuf, n);
+
+       /* Length prefix is in: allocate buffers of the announced size. */
+       if (sldns_buffer_position(pq->qbuf) >= 2 && !pq->abuf) {
+               struct sldns_buffer     *tmp;
+               size_t                   rem;
+               uint16_t                 len;
+
+               sldns_buffer_flip(pq->qbuf);
+               len = sldns_buffer_read_u16(pq->qbuf);
+               tmp = sldns_buffer_new(len);
+               pq->abuf = sldns_buffer_new(len);
+
+               if (!tmp || !pq->abuf) {
+                       sldns_buffer_free(tmp); /* not owned by pq yet */
+                       goto fail;
+               }
+
+               /* tmp holds len bytes; drop anything the peer sent beyond. */
+               rem = sldns_buffer_remaining(pq->qbuf);
+               sldns_buffer_write(tmp, sldns_buffer_current(pq->qbuf),
+                   rem > len ? len : rem);
+               sldns_buffer_free(pq->qbuf);
+               pq->qbuf = tmp;
+       }
+       if (sldns_buffer_remaining(pq->qbuf) == 0) {
+               sldns_buffer_flip(pq->qbuf);
+               shutdown(fd, SHUT_RD);
+               event_del(&pq->ev);
+               handle_query(pq);
+       }
+       return;
+fail:
+       free_pending_query(pq);
+}
+
+void
+tcp_response(int fd, short events, void *arg)
+{
+       struct pending_query    *pq;
+       ssize_t                  n;
+
+       pq = arg;
+
+       n = write(fd, sldns_buffer_current(pq->abuf),
+           sldns_buffer_remaining(pq->abuf));
+       /* A hard error frees pq, so return before pq->abuf is used again. */
+       if (n == -1) {
+               if (errno != EAGAIN && errno != EINTR)
+                       free_pending_query(pq);
+               return;
+       }
+       sldns_buffer_skip(pq->abuf, n);
+       if (sldns_buffer_remaining(pq->abuf) == 0)
+               free_pending_query(pq);
+}
+
+void
+tcp_timeout(int fd, short events, void *arg)
+{
+       free_pending_query(arg);
+}
diff --git unwind.c unwind.c
index cde7c2d0dc8..69f63428c60 100644
--- unwind.c
+++ unwind.c
@@ -725,6 +725,7 @@ open_ports(void)
 {
        struct addrinfo  hints, *res0;
        int              udp4sock = -1, udp6sock = -1, error, bsize = 65535;
+       int              tcp4sock = -1, tcp6sock = -1;
        int              opt = 1;
 
        memset(&hints, 0, sizeof(hints));
@@ -773,13 +774,73 @@ open_ports(void)
        if (res0)
                freeaddrinfo(res0);
 
-       if (udp4sock == -1 && udp6sock == -1)
-               fatal("could not bind to 127.0.0.1 or ::1 on port 53");
+       hints.ai_family = AF_INET;
+       hints.ai_socktype = SOCK_STREAM;
+
+       error = getaddrinfo("127.0.0.1", "domain", &hints, &res0);
+       if (!error && res0) {
+               if ((tcp4sock = socket(res0->ai_family,
+                   res0->ai_socktype | SOCK_NONBLOCK,
+                   res0->ai_protocol)) != -1) {
+                       if (setsockopt(tcp4sock, SOL_SOCKET, SO_REUSEADDR,
+                           &opt, sizeof(opt)) == -1)
+                               log_warn("setting SO_REUSEADDR on socket");
+                       if (setsockopt(tcp4sock, SOL_SOCKET, SO_SNDBUF, &bsize,
+                           sizeof(bsize)) == -1)
+                               log_warn("setting SO_SNDBUF on socket");
+                       if (bind(tcp4sock, res0->ai_addr, res0->ai_addrlen)
+                           == -1) {
+                               close(tcp4sock);
+                               tcp4sock = -1;
+                       }
+                       if (listen(tcp4sock, 5) == -1) {
+                               close(tcp4sock);
+                               tcp4sock = -1;
+                       }
+               }
+       }
+       if (res0)
+               freeaddrinfo(res0);
+
+       hints.ai_family = AF_INET6;
+       error = getaddrinfo("::1", "domain", &hints, &res0);
+       if (!error && res0) {
+               if ((tcp6sock = socket(res0->ai_family,
+                   res0->ai_socktype | SOCK_NONBLOCK,
+                   res0->ai_protocol)) != -1) {
+                       if (setsockopt(tcp6sock, SOL_SOCKET, SO_REUSEADDR,
+                           &opt, sizeof(opt)) == -1)
+                               log_warn("setting SO_REUSEADDR on socket");
+                       if (setsockopt(tcp6sock, SOL_SOCKET, SO_SNDBUF, &bsize,
+                           sizeof(bsize)) == -1)
+                               log_warn("setting SO_SNDBUF on socket");
+                       if (bind(tcp6sock, res0->ai_addr, res0->ai_addrlen)
+                           == -1) {
+                               close(tcp6sock);
+                               tcp6sock = -1;
+                       }
+                       if (listen(tcp6sock, 5) == -1) {
+                               close(tcp6sock);
+                               tcp6sock = -1;
+                       }
+               }
+       }
+       if (res0)
+               freeaddrinfo(res0);
+
+       if ((udp4sock == -1 || tcp4sock == -1) && (udp6sock == -1 ||
+           tcp6sock == -1))
+               //fatalx("could not bind to 127.0.0.1 or ::1 on port 53");
+               fatalx("could not bind to 127.0.0.1 or ::1 on port 53 %d %d %d %d", udp4sock, tcp4sock, udp6sock, tcp6sock);
 
        if (udp4sock != -1)
                main_imsg_compose_frontend_fd(IMSG_UDP4SOCK, 0, udp4sock);
        if (udp6sock != -1)
                main_imsg_compose_frontend_fd(IMSG_UDP6SOCK, 0, udp6sock);
+       if (tcp4sock != -1)
+               main_imsg_compose_frontend_fd(IMSG_TCP4SOCK, 0, tcp4sock);
+       if (tcp6sock != -1)
+               main_imsg_compose_frontend_fd(IMSG_TCP6SOCK, 0, tcp6sock);
 }
 
 void
diff --git unwind.h unwind.h
index 659d94639e9..b2c6d378836 100644
--- unwind.h
+++ unwind.h
@@ -109,6 +109,8 @@ enum imsg_type {
        IMSG_RECONF_END,
        IMSG_UDP4SOCK,
        IMSG_UDP6SOCK,
+       IMSG_TCP4SOCK,
+       IMSG_TCP6SOCK,
        IMSG_ROUTESOCK,
        IMSG_CONTROLFD,
        IMSG_STARTUP,


-- 
I'm not entirely sure you are real.

Reply via email to