On Fri, Jan 08, 2021 at 06:37:35PM +0100, Florian Obser wrote:
> Rewrite query parsing and answer formatting using libunbound-provided
> functions.
> With this we can filter out DNSSEC RRsets if the client did not ask
> for them.
> We will also be able to send truncated answers to indicate to the
> client to switch to tcp (disabled for now since we have no tcp support
> yet).
> 
> Tests, OKs?
> 

I've recut this diff & the tcp diff.

This fixes another memory leak in the use of reply_info_parse()
and contains some minor code shuffling that came up while working on
tcp. E.g. free_pending_query now unhooks the query from the tailq, so
we must make sure to add the struct to the queue as soon as we have
allocated it successfully.

Tests, OKs?

diff --git frontend.c frontend.c
index 9f91396c4c3..0462f37a258 100644
--- frontend.c
+++ frontend.c
@@ -48,7 +48,11 @@
 #include "libunbound/sldns/sbuffer.h"
 #include "libunbound/sldns/str2wire.h"
 #include "libunbound/sldns/wire2str.h"
+#include "libunbound/util/alloc.h"
+#include "libunbound/util/net_help.h"
+#include "libunbound/util/regional.h"
 #include "libunbound/util/data/dname.h"
+#include "libunbound/util/data/msgencode.h"
 #include "libunbound/util/data/msgparse.h"
 #include "libunbound/util/data/msgreply.h"
 
@@ -68,6 +72,7 @@
  * 2 octets RDLENGTH
  */
 #define COMPRESSED_RR_SIZE     12
+#define MINIMIZE_ANSWER                1       /* passed to reply_info_encode() */
 
 struct udp_ev {
        struct event             ev;
@@ -81,14 +86,13 @@ struct pending_query {
        TAILQ_ENTRY(pending_query)       entry;
        struct sockaddr_storage          from;
        struct sldns_buffer             *qbuf;
-       ssize_t                          len;
+       struct sldns_buffer             *abuf;  /* answer, sent by send_answer() */
+       struct regional                 *region; /* backs qmsg and parse data */
+       struct query_info                qinfo;  /* parsed question */
+       struct msg_parse                *qmsg;   /* parsed query message */
+       struct edns_data                 edns;   /* EDNS data of the query */
        uint64_t                         imsg_id;
        int                              fd;
-       int                              bogus;
-       int                              rcode_override;
-       int                              answer_len;
-       int                              received;
-       uint8_t                         *answer;
 };
 
 TAILQ_HEAD(, pending_query)     pending_queries;
@@ -102,8 +106,12 @@ __dead void                 frontend_shutdown(void);
 void                    frontend_sig_handler(int, short, void *);
 void                    frontend_startup(void);
 void                    udp_receive(int, short, void *);
+void                    handle_query(struct pending_query *);
+void                    free_pending_query(struct pending_query *);
 int                     check_query(sldns_buffer*);
+void                    noerror_answer(struct pending_query *);
 void                    chaos_answer(struct pending_query *);
+void                    error_answer(struct pending_query *, int rcode);
 void                    send_answer(struct pending_query *);
 void                    route_receive(int, short, void *);
 void                    handle_route_message(struct rt_msghdr *,
@@ -471,37 +479,38 @@ frontend_dispatch_resolver(int fd, short event, void *bula)
                        }
 
                        if (answer_header->srvfail) {
-                               free(pq->answer);
-                               pq->answer_len = 0;
-                               pq->answer = NULL;
-                               pq->rcode_override = LDNS_RCODE_SERVFAIL;
+                               error_answer(pq, LDNS_RCODE_SERVFAIL);
                                send_answer(pq);
                                break;
                        }
 
-                       if (pq->answer == NULL) {
-                               pq->answer = malloc(answer_header->answer_len);
-                               if (pq->answer == NULL) {
-                                       pq->answer_len = 0;
-                                       pq->rcode_override =
-                                           LDNS_RCODE_SERVFAIL;
-                                       send_answer(pq);
-                                       break;
-                               }
-                               pq->answer_len = answer_header->answer_len;
-                               pq->received = 0;
-                               pq->bogus = answer_header->bogus;
+                       if (answer_header->bogus && !(pq->qmsg->flags &
+                           BIT_CD)) {  /* bogus and client did not set CD */
+                               error_answer(pq, LDNS_RCODE_SERVFAIL);
+                               send_answer(pq);
+                               break;
+                       }
+
+                       if (sldns_buffer_position(pq->abuf) == 0 &&
+                           !sldns_buffer_set_capacity(pq->abuf,
+                           answer_header->answer_len)) {  /* first chunk */
+                               error_answer(pq, LDNS_RCODE_SERVFAIL);
+                               send_answer(pq);
+                               break;
                        }
 
-                       if (pq->received + data_len > pq->answer_len)
+                       if (sldns_buffer_position(pq->abuf) + data_len >
+                           sldns_buffer_capacity(pq->abuf))
                                fatalx("%s: IMSG_ANSWER answer too big: %d",
                                    __func__, data_len);
+                       sldns_buffer_write(pq->abuf, data, data_len);
 
-                       memcpy(pq->answer + pq->received, data, data_len);
-                       pq->received += data_len;
-
-                       if (pq->received == pq->answer_len)
+                       if (sldns_buffer_position(pq->abuf) ==
+                           sldns_buffer_capacity(pq->abuf)) {  /* all received */
+                               sldns_buffer_flip(pq->abuf);
+                               noerror_answer(pq);
                                send_answer(pq);
+                       }
                        break;
                }
                case IMSG_CTL_RESOLVER_INFO:
@@ -558,22 +567,25 @@ frontend_startup(void)
        frontend_imsg_compose_main(IMSG_STARTUP_DONE, 0, NULL, 0);
 }
 
+void
+free_pending_query(struct pending_query *pq)
+{
+       if (!pq)
+               return;
+
+       TAILQ_REMOVE(&pending_queries, pq, entry);      /* pq must be queued */
+       regional_destroy(pq->region);   /* members may be NULL, see udp_receive() */
+       sldns_buffer_free(pq->qbuf);
+       sldns_buffer_free(pq->abuf);
+       free(pq);
+}
+
 void
 udp_receive(int fd, short events, void *arg)
 {
        struct udp_ev           *udpev = (struct udp_ev *)arg;
-       struct pending_query    *pq;
-       struct query_imsg        query_imsg;
-       struct query_info        qinfo;
-       struct bl_node           find;
-       ssize_t                  len, dname_len;
-       int                      ret;
-       char                    *str;
-       char                     dname[LDNS_MAX_DOMAINLEN + 1];
-       char                     qclass_buf[16];
-       char                     qtype_buf[16];
-
-       memset(&qinfo, 0, sizeof(qinfo));
+       struct pending_query    *pq = NULL;
+       ssize_t                  len;
 
        if ((len = recvmsg(fd, &udpev->rcvmhdr, 0)) == -1) {
                log_warn("recvmsg");
@@ -585,161 +597,236 @@ udp_receive(int fd, short events, void *arg)
                return;
        }
 
-       pq->rcode_override = LDNS_RCODE_NOERROR;
-       pq->len = len;
-       pq->from = udpev->from;
-       pq->fd = fd;
-
        do {
                arc4random_buf(&pq->imsg_id, sizeof(pq->imsg_id));
        } while(find_pending_query(pq->imsg_id) != NULL);
 
-       if ((pq->qbuf = sldns_buffer_new(len)) == NULL) {
-               log_warnx("sldns_buffer_new");
-               goto drop;
+       TAILQ_INSERT_TAIL(&pending_queries, pq, entry);
+
+       pq->from = udpev->from;
+       pq->fd = fd;
+       pq->qbuf = sldns_buffer_new(len);
+       pq->abuf = sldns_buffer_new(len); /* make sure we can send errors */
+       pq->region = regional_create();
+       pq->qmsg = regional_alloc(pq->region, sizeof(*pq->qmsg));
+
+       if (!pq->qbuf || !pq->abuf || !pq->region || !pq->qmsg) {
+               log_warnx("out of memory");
+               free_pending_query(pq);
+               return;
        }
-       sldns_buffer_clear(pq->qbuf);
+
+       memset(pq->qmsg, 0, sizeof(*pq->qmsg));
        sldns_buffer_write(pq->qbuf, udpev->query, len);
        sldns_buffer_flip(pq->qbuf);
+       handle_query(pq);
+}
+
+void
+handle_query(struct pending_query *pq)
+{
+       struct query_imsg        query_imsg;
+       struct bl_node           find;
+       int                      rcode;
+       char                    *str;
+       char                     dname[LDNS_MAX_DOMAINLEN + 1];
+       char                     qclass_buf[16];
+       char                     qtype_buf[16];
 
        if (log_getverbose() & OPT_VERBOSE2 && (str =
-           sldns_wire2str_pkt(udpev->query, len)) != NULL) {
+           sldns_wire2str_pkt(sldns_buffer_begin(pq->qbuf),
+           sldns_buffer_limit(pq->qbuf))) != NULL) {
                log_debug("from: %s\n%s", ip_port((struct sockaddr *)
-                   &udpev->from), str);
+                   &pq->from), str);
                free(str);
        }
 
-       if ((ret = check_query(pq->qbuf)) != LDNS_RCODE_NOERROR) {
-               if (ret == -1)
-                       goto drop;
-               else
-                       pq->rcode_override = ret;
+       if (!query_info_parse(&pq->qinfo, pq->qbuf)) {
+               log_warnx("query_info_parse failed");
+               goto drop;
+       }
+       /* also parse the whole message, for EDNS data and header flags */
+       sldns_buffer_rewind(pq->qbuf);  /* second parse starts at offset 0 */
+
+       if (parse_packet(pq->qbuf, pq->qmsg, pq->region) !=
+           LDNS_RCODE_NOERROR) {
+               log_warnx("parse_packet failed");
+               goto drop;
+       }
+
+       rcode = check_query(pq->qbuf);
+       switch (rcode) {
+       case LDNS_RCODE_NOERROR:
+               break;
+       case -1:        /* drop without answering */
+               goto drop;
+       default:        /* answer with the rcode from check_query() */
+               error_answer(pq, rcode);
                goto send_answer;
        }
 
-       if (!query_info_parse(&qinfo, pq->qbuf)) {
-               pq->rcode_override = LDNS_RCODE_FORMERR;
+       rcode = parse_extract_edns(pq->qmsg, &pq->edns, pq->region);
+       if (rcode != LDNS_RCODE_NOERROR) {
+               error_answer(pq, rcode);
                goto send_answer;
        }
 
-       if ((dname_len = dname_valid(qinfo.qname, qinfo.qname_len)) == 0) {
-               pq->rcode_override = LDNS_RCODE_FORMERR;
+       if (!dname_valid(pq->qinfo.qname, pq->qinfo.qname_len)) {
+               error_answer(pq, LDNS_RCODE_FORMERR);
                goto send_answer;
        }
-       dname_str(qinfo.qname, dname);
+       dname_str(pq->qinfo.qname, dname);
 
-       sldns_wire2str_class_buf(qinfo.qclass, qclass_buf, sizeof(qclass_buf));
-       sldns_wire2str_type_buf(qinfo.qtype, qtype_buf, sizeof(qtype_buf));
-       log_debug("%s: %s %s %s ?", ip_port((struct sockaddr *)&udpev->from),
+       sldns_wire2str_class_buf(pq->qinfo.qclass, qclass_buf,
+           sizeof(qclass_buf));
+       sldns_wire2str_type_buf(pq->qinfo.qtype, qtype_buf, sizeof(qtype_buf));
+       log_debug("%s: %s %s %s ?", ip_port((struct sockaddr *)&pq->from),
            dname, qclass_buf, qtype_buf);
 
        find.domain = dname;
        if (RB_FIND(bl_tree, &bl_head, &find) != NULL) {
                if (frontend_conf->blocklist_log)
                        log_info("blocking %s", dname);
-               pq->rcode_override = LDNS_RCODE_REFUSED;
+               error_answer(pq, LDNS_RCODE_REFUSED);
                goto send_answer;
        }
 
-       if (qinfo.qtype == LDNS_RR_TYPE_AXFR || qinfo.qtype ==
+       if (pq->qinfo.qtype == LDNS_RR_TYPE_AXFR || pq->qinfo.qtype ==
            LDNS_RR_TYPE_IXFR) {
-               pq->rcode_override = LDNS_RCODE_REFUSED;
+               error_answer(pq, LDNS_RCODE_REFUSED);
                goto send_answer;
        }
 
-       if(qinfo.qtype == LDNS_RR_TYPE_OPT ||
-           qinfo.qtype == LDNS_RR_TYPE_TSIG ||
-           qinfo.qtype == LDNS_RR_TYPE_TKEY ||
-           qinfo.qtype == LDNS_RR_TYPE_MAILA ||
-           qinfo.qtype == LDNS_RR_TYPE_MAILB ||
-           (qinfo.qtype >= 128 && qinfo.qtype <= 248)) {
-               pq->rcode_override = LDNS_RCODE_FORMERR;
+       if(pq->qinfo.qtype == LDNS_RR_TYPE_OPT ||
+           pq->qinfo.qtype == LDNS_RR_TYPE_TSIG ||
+           pq->qinfo.qtype == LDNS_RR_TYPE_TKEY ||
+           pq->qinfo.qtype == LDNS_RR_TYPE_MAILA ||
+           pq->qinfo.qtype == LDNS_RR_TYPE_MAILB ||
+           (pq->qinfo.qtype >= 128 && pq->qinfo.qtype <= 248)) {
+               error_answer(pq, LDNS_RCODE_FORMERR);
                goto send_answer;
        }
 
-       if (qinfo.qclass == LDNS_RR_CLASS_CH) {
+       if (pq->qinfo.qclass == LDNS_RR_CLASS_CH) {
                if (strcasecmp(dname, "version.server.") == 0 ||
                    strcasecmp(dname, "version.bind.") == 0) {
                        chaos_answer(pq);
                } else
-                       pq->rcode_override = LDNS_RCODE_REFUSED;
+                       error_answer(pq, LDNS_RCODE_REFUSED);
                goto send_answer;
        }
 
        if (strlcpy(query_imsg.qname, dname, sizeof(query_imsg.qname)) >=
            sizeof(query_imsg.qname)) {
                log_warnx("qname too long");
-               pq->rcode_override = LDNS_RCODE_FORMERR;
+               error_answer(pq, LDNS_RCODE_FORMERR);
                goto send_answer;
        }
        query_imsg.id = pq->imsg_id;
-       query_imsg.t = qinfo.qtype;
-       query_imsg.c = qinfo.qclass;
+       query_imsg.t = pq->qinfo.qtype;
+       query_imsg.c = pq->qinfo.qclass;
 
        if (frontend_imsg_compose_resolver(IMSG_QUERY, 0, &query_imsg,
-           sizeof(query_imsg)) != -1)
-               TAILQ_INSERT_TAIL(&pending_queries, pq, entry);
-       else {
-               pq->rcode_override = LDNS_RCODE_SERVFAIL;
+           sizeof(query_imsg)) == -1) {
+               error_answer(pq, LDNS_RCODE_SERVFAIL);
                goto send_answer;
        }
        return;
 
  send_answer:
-       TAILQ_INSERT_TAIL(&pending_queries, pq, entry);
        send_answer(pq);
-       pq = NULL;
+       return;
+
  drop:
-       if (pq != NULL)
-               sldns_buffer_free(pq->qbuf);
-       free(pq);
+       free_pending_query(pq);        /* no answer is sent */
+}
+
+void
+noerror_answer(struct pending_query *pq)
+{
+       struct query_info        skip, qinfo;
+       struct reply_info       *rinfo = NULL;
+       struct alloc_cache       alloc;
+       struct edns_data         edns;
+
+       alloc_init(&alloc, NULL, 0);
+       memset(&qinfo, 0, sizeof(qinfo));
+       /* read past query section, no memory is allocated */
+       if (!query_info_parse(&skip, pq->abuf))
+               goto srvfail;
+
+       if (reply_info_parse(pq->abuf, &alloc, &qinfo, &rinfo, pq->region,
+           &edns) != 0)
+               goto srvfail;
+       /* reply_info_parse() allocates qinfo.qname; only rinfo is used */
+       query_info_clear(&qinfo);
+       /* NOTE(review): pq->edns not passed (cf. error_answer()); OPT RR? */
+       sldns_buffer_clear(pq->abuf);   /* re-encode into the same buffer */
+       if (reply_info_encode(&pq->qinfo, rinfo, pq->qmsg->id, rinfo->flags,
+           pq->abuf, 0, pq->region, UINT16_MAX, /* XXX pq->edns.udp_size, */
+           pq->edns.bits & EDNS_DO, MINIMIZE_ANSWER) == 0)
+               goto srvfail;
+
+       reply_info_parsedelete(rinfo, &alloc);
+       alloc_clear(&alloc);
+       return;
+
+ srvfail:
+       reply_info_parsedelete(rinfo, &alloc);
+       alloc_clear(&alloc);
+       error_answer(pq, LDNS_RCODE_SERVFAIL);
 }
 
 void
 chaos_answer(struct pending_query *pq)
 {
-       struct sldns_buffer      buf, *pkt = &buf;
-       size_t                   size, len;
-       char                    *name = "unwind";
+       size_t           len;
+       const char      *name = "unwind";
 
        len = strlen(name);
-       size = sldns_buffer_capacity(pq->qbuf) + COMPRESSED_RR_SIZE + 1 + len;
-
-       if (pq->answer != 0)
-               fatal("chaos_answer");
-       if ((pq->answer = calloc(1, size)) == NULL)
+       if (!sldns_buffer_set_capacity(pq->abuf,
+           sldns_buffer_capacity(pq->qbuf) + COMPRESSED_RR_SIZE + 1 + len)) {
+               error_answer(pq, LDNS_RCODE_SERVFAIL);
                return;
-       pq->answer_len = size;
-       memcpy(pq->answer, sldns_buffer_begin(pq->qbuf),
-           sldns_buffer_capacity(pq->qbuf));
-       sldns_buffer_init_frm_data(pkt, pq->answer, size);
+       }
+       /* the answer starts out as a copy of the query */
+       sldns_buffer_copy(pq->abuf, pq->qbuf);
 
-       sldns_buffer_clear(pkt);
+       sldns_buffer_clear(pq->abuf);
 
-       sldns_buffer_skip(pkt, sizeof(uint16_t));       /* skip id */
-       sldns_buffer_write_u16(pkt, 0);                 /* clear flags */
-       LDNS_QR_SET(sldns_buffer_begin(pkt));
-       LDNS_RA_SET(sldns_buffer_begin(pkt));
+       sldns_buffer_skip(pq->abuf, sizeof(uint16_t));  /* skip id */
+       sldns_buffer_write_u16(pq->abuf, 0);            /* clear flags */
+       LDNS_QR_SET(sldns_buffer_begin(pq->abuf));
+       LDNS_RA_SET(sldns_buffer_begin(pq->abuf));
        if (LDNS_RD_WIRE(sldns_buffer_begin(pq->qbuf)))
-               LDNS_RD_SET(sldns_buffer_begin(pkt));
+               LDNS_RD_SET(sldns_buffer_begin(pq->abuf));
        if (LDNS_CD_WIRE(sldns_buffer_begin(pq->qbuf)))
-               LDNS_CD_SET(sldns_buffer_begin(pkt));
-       LDNS_RCODE_SET(sldns_buffer_begin(pkt), LDNS_RCODE_NOERROR);
-       sldns_buffer_write_u16(pkt, 1);                 /* qdcount */
-       sldns_buffer_write_u16(pkt, 1);                 /* ancount */
-       sldns_buffer_write_u16(pkt, 0);                 /* nscount */
-       sldns_buffer_write_u16(pkt, 0);                 /* arcount */
-       (void)query_dname_len(pkt);                     /* skip qname */
-       sldns_buffer_skip(pkt, sizeof(uint16_t));       /* skip qtype */
-       sldns_buffer_skip(pkt, sizeof(uint16_t));       /* skip qclass */
-
-       sldns_buffer_write_u16(pkt, 0xc00c);            /* ptr to query */
-       sldns_buffer_write_u16(pkt, LDNS_RR_TYPE_TXT);
-       sldns_buffer_write_u16(pkt, LDNS_RR_CLASS_CH);
-       sldns_buffer_write_u32(pkt, 0);                 /* TTL */
-       sldns_buffer_write_u16(pkt, 1 + len);           /* RDLENGTH */
-       sldns_buffer_write_u8(pkt, len);                /* length octed */
-       sldns_buffer_write(pkt, name, len);
+               LDNS_CD_SET(sldns_buffer_begin(pq->abuf));
+       LDNS_RCODE_SET(sldns_buffer_begin(pq->abuf), LDNS_RCODE_NOERROR);
+       sldns_buffer_write_u16(pq->abuf, 1);            /* qdcount */
+       sldns_buffer_write_u16(pq->abuf, 1);            /* ancount */
+       sldns_buffer_write_u16(pq->abuf, 0);            /* nscount */
+       sldns_buffer_write_u16(pq->abuf, 0);            /* arcount */
+       (void)query_dname_len(pq->abuf);                /* skip qname */
+       sldns_buffer_skip(pq->abuf, sizeof(uint16_t));  /* skip qtype */
+       sldns_buffer_skip(pq->abuf, sizeof(uint16_t));  /* skip qclass */
+       /* answer RR: CH TXT "unwind", owner name points at the qname */
+       sldns_buffer_write_u16(pq->abuf, 0xc00c);       /* ptr to query */
+       sldns_buffer_write_u16(pq->abuf, LDNS_RR_TYPE_TXT);
+       sldns_buffer_write_u16(pq->abuf, LDNS_RR_CLASS_CH);
+       sldns_buffer_write_u32(pq->abuf, 0);            /* TTL */
+       sldns_buffer_write_u16(pq->abuf, 1 + len);      /* RDLENGTH */
+       sldns_buffer_write_u8(pq->abuf, len);           /* length octet */
+       sldns_buffer_write(pq->abuf, name, len);
+       sldns_buffer_flip(pq->abuf);
+}
+
+void
+error_answer(struct pending_query *pq, int rcode)
+{
+       sldns_buffer_clear(pq->abuf);   /* discard previous contents */
+       error_encode(pq->abuf, rcode, &pq->qinfo, pq->qmsg->id,
+           pq->qmsg->flags, pq->edns.edns_present ? &pq->edns : NULL); /* EDNS only if the query had it */
 }
 
 int
@@ -786,59 +873,22 @@ check_query(sldns_buffer* pkt)
 void
 send_answer(struct pending_query *pq)
 {
-       ssize_t  len;
        char    *str;
-       uint8_t *answer;
-
-       answer = pq->answer;
-       len = pq->answer_len;
-
-       if (answer == NULL) {
-               answer = sldns_buffer_begin(pq->qbuf);
-               len = pq->len;
-
-               LDNS_QR_SET(answer);
-               LDNS_RA_SET(answer);
-               if (pq->rcode_override != LDNS_RCODE_NOERROR)
-                       LDNS_RCODE_SET(answer, pq->rcode_override);
-               else
-                       LDNS_RCODE_SET(answer, LDNS_RCODE_SERVFAIL);
-       } else {
-               if (pq->bogus) {
-                       if(LDNS_CD_WIRE(sldns_buffer_begin(pq->qbuf))) {
-                               LDNS_ID_SET(answer, LDNS_ID_WIRE(
-                                   sldns_buffer_begin(pq->qbuf)));
-                               LDNS_CD_SET(answer);
-                       } else {
-                               answer = sldns_buffer_begin(pq->qbuf);
-                               len = pq->len;
-
-                               LDNS_QR_SET(answer);
-                               LDNS_RA_SET(answer);
-                               LDNS_RCODE_SET(answer, LDNS_RCODE_SERVFAIL);
-                       }
-               } else {
-                       LDNS_ID_SET(answer, LDNS_ID_WIRE(sldns_buffer_begin(
-                           pq->qbuf)));
-               }
-       }
 
        if (log_getverbose() & OPT_VERBOSE2 && (str =
-           sldns_wire2str_pkt(answer, len)) != NULL) {
-               log_debug("to: %s\n%s",
-                   ip_port((struct sockaddr *)&pq->from),str);
+           sldns_wire2str_pkt(sldns_buffer_begin(pq->abuf),
+           sldns_buffer_limit(pq->abuf))) != NULL) {
+               log_debug("to: %s\n%s", ip_port((struct sockaddr *)
+                   &pq->from), str);
                free(str);
-               log_debug("pending query count: %d", pending_query_cnt());
        }
 
-       if(sendto(pq->fd, answer, len, 0, (struct sockaddr *)&pq->from,
-           pq->from.ss_len) == -1)
+       if (sendto(pq->fd, sldns_buffer_current(pq->abuf),
+           sldns_buffer_remaining(pq->abuf), 0,
+           (struct sockaddr *)&pq->from, pq->from.ss_len) == -1)
                log_warn("sendto");
 
-       TAILQ_REMOVE(&pending_queries, pq, entry);
-       sldns_buffer_free(pq->qbuf);
-       free(pq->answer);
-       free(pq);
+       free_pending_query(pq);        /* unlinks pq from the queue, frees it */
 }
 
 char*


-- 
I'm not entirely sure you are real.

Reply via email to