Did I get this right? I'd appreciate it if someone could give this a
once-over.

    Since resolve() switched to a callback mechanism, all uw_resolver
    objects pass through resolve() and either asr_resolve_done() or
    ub_resolve_done().
    With that, we can pull resolver_ref() and resolver_unref() into those
    functions to make the reference counting easier.
    Only check_resolver() is special, since it needs to refcount the
    resolver being checked. The resolver doing the actual work, however, is
    refcounted automatically by resolve() and *_resolve_done().
    One last piece of the puzzle is to track the uw_resolver object in
    cb_data so that the *_resolve_done() functions have access to it.
    This also allows us to remove the ad-hoc passing of the resolver in
    query_imsg. Since all the callback functions need access to the
    resolver that did the work, we pass it in as the first argument.

diff --git resolver.c resolver.c
index d1dce2dec71..2b7d81d29fc 100644
--- resolver.c
+++ resolver.c
@@ -92,14 +92,11 @@ struct uw_resolver {
        int64_t                  histogram[nitems(histogram_limits)];
 };
 
-struct check_resolver_data {
-       struct uw_resolver      *res;
-       struct uw_resolver      *check_res;
-};
-
 struct resolver_cb_data {
-       void    (*cb)(void *, int, void *, int, int, char *);
-       void    *data;
+       void                    (*cb)(struct uw_resolver *, void *, int, void *,
+                                   int, int, char *);
+       void                    *data;
+       struct uw_resolver      *res;
 };
 
 __dead void             resolver_shutdown(void);
@@ -108,9 +105,10 @@ void                        
resolver_dispatch_frontend(int, short, void *);
 void                    resolver_dispatch_captiveportal(int, short, void *);
 void                    resolver_dispatch_main(int, short, void *);
 int                     resolve(struct uw_resolver *, const char*, int, int,
-                            void*, void (*cb)(void *, int, void *, int, int,
-                            char *));
-void                    resolve_done(void *, int, void *, int, int, char *);
+                            void*, void (*cb)(struct uw_resolver *, void *,
+                            int, void *, int, int, char *));
+void                    resolve_done(struct uw_resolver *, void *, int, void *,
+                            int, int, char *);
 void                    ub_resolve_done(void *, int, void *, int, int, char *,
                             int);
 void                    asr_resolve_done(struct asr_result *, void *);
@@ -129,8 +127,8 @@ void                         set_forwarders_oppdot(struct 
uw_resolver *,
 void                    resolver_check_timo(int, short, void *);
 void                    resolver_free_timo(int, short, void *);
 void                    check_resolver(struct uw_resolver *);
-void                    check_resolver_done(void *, int, void *, int, int,
-                            char *);
+void                    check_resolver_done(struct uw_resolver *, void *, int,
+                            void *, int, int, char *);
 void                    schedule_recheck_all_resolvers(void);
 int                     check_forwarders_changed(struct uw_forwarder_head *,
                             struct uw_forwarder_head *);
@@ -154,8 +152,8 @@ int                  check_captive_portal_changed(struct 
uw_conf *,
                             struct uw_conf *);
 void                    trust_anchor_resolve(void);
 void                    trust_anchor_timo(int, short, void *);
-void                    trust_anchor_resolve_done(void *, int, void *, int,
-                            int, char *);
+void                    trust_anchor_resolve_done(struct uw_resolver *, void *,
+                            int, void *, int, int, char *);
 void                    add_autoconf_forwarders(struct imsg_rdns_proposal *);
 void                    rem_autoconf_forwarders(struct imsg_rdns_proposal *);
 struct uw_forwarder    *find_forwarder(struct uw_forwarder_head *,
@@ -480,14 +478,10 @@ resolver_dispatch_frontend(int fd, short event, void 
*bula)
                        log_debug("%s: choosing %s", __func__,
                            uw_resolver_type_str[res->type]);
 
-                       query_imsg->resolver = res;
-                       resolver_ref(res);
-
                        clock_gettime(CLOCK_MONOTONIC, &query_imsg->tp);
 
-                       if ((resolve(res, query_imsg->qname, query_imsg->t,
-                           query_imsg->c, query_imsg, resolve_done)) != 0)
-                               resolver_unref(res);
+                       resolve(res, query_imsg->qname, query_imsg->t,
+                           query_imsg->c, query_imsg, resolve_done);
                        break;
                case IMSG_FORWARDER:
                        /* make sure this is a string */
@@ -773,16 +767,20 @@ resolver_dispatch_main(int fd, short event, void *bula)
 
 int
 resolve(struct uw_resolver *res, const char* name, int rrtype, int rrclass,
-    void *mydata, void (*cb)(void *, int, void *, int, int, char *))
+    void *mydata, void (*cb)(struct uw_resolver *, void *, int, void *, int,
+    int, char *))
 {
        struct resolver_cb_data *cb_data = NULL;
        struct asr_query        *aq = NULL;
        int                      err;
 
+       resolver_ref(res);
+
        if ((cb_data = malloc(sizeof(*cb_data))) == NULL)
                goto err;
        cb_data->cb = cb;
        cb_data->data = mydata;
+       cb_data->res = res;
 
        switch(res->type) {
        case UW_RES_ASR:
@@ -816,15 +814,15 @@ resolve(struct uw_resolver *res, const char* name, int 
rrtype, int rrclass,
  err:
        free(cb_data);
        free(aq);
+       resolver_unref(res);
        return 1;
 }
 
 void
-resolve_done(void *arg, int rcode, void *answer_packet, int answer_len,
-    int sec, char *why_bogus)
+resolve_done(struct uw_resolver *res, void *arg, int rcode,
+    void *answer_packet, int answer_len, int sec, char *why_bogus)
 {
        struct query_imsg       *query_imsg;
-       struct uw_resolver      *res;
        struct timespec          tp, elapsed;
        int64_t                  ms;
        size_t                   i;
@@ -833,7 +831,6 @@ resolve_done(void *arg, int rcode, void *answer_packet, int 
answer_len,
        clock_gettime(CLOCK_MONOTONIC, &tp);
 
        query_imsg = (struct query_imsg *)arg;
-       res = (struct uw_resolver *)query_imsg->resolver;
 
        timespecsub(&tp, &query_imsg->tp, &elapsed);
 
@@ -886,7 +883,6 @@ resolve_done(void *arg, int rcode, void *answer_packet, int 
answer_len,
            answer_packet, answer_len);
 
        free(query_imsg);
-       resolver_unref(res);
        return;
 
 servfail:
@@ -894,7 +890,6 @@ servfail:
        resolver_imsg_compose_frontend(IMSG_ANSWER_HEADER, 0, query_imsg,
            sizeof(*query_imsg));
        free(query_imsg);
-       resolver_unref(res);
 }
 
 void
@@ -1259,95 +1254,83 @@ resolver_free_timo(int fd, short events, void *arg)
 }
 
 void
-check_resolver(struct uw_resolver *res)
+check_resolver(struct uw_resolver *resolver_to_check)
 {
-       struct uw_resolver              *check_res;
-       struct check_resolver_data      *data;
+       struct uw_resolver              *res;
 
        log_debug("%s: create_resolver", __func__);
-       if ((check_res = create_resolver(res->type, 0)) == NULL)
-               fatal("%s", __func__);
-       if ((data = malloc(sizeof(*data))) == NULL)
+       if ((res = create_resolver(resolver_to_check->type, 0)) == NULL)
                fatal("%s", __func__);
 
-       resolver_ref(check_res);
-       resolver_ref(res);
-       data->check_res = check_res;
-       data->res = res;
+       resolver_ref(resolver_to_check);
 
-       if (resolve(check_res, ".", LDNS_RR_TYPE_NS, LDNS_RR_CLASS_IN,
-           data, check_resolver_done) != 0) {
-               res->state = UNKNOWN;
-               resolver_unref(check_res);
-               resolver_unref(res);
-               res->check_tv.tv_sec = RESOLVER_CHECK_SEC;
-               evtimer_add(&res->check_ev, &res->check_tv);
+       if (resolve(res, ".", LDNS_RR_TYPE_NS, LDNS_RR_CLASS_IN,
+           resolver_to_check, check_resolver_done) != 0) {
+               resolver_to_check->state = UNKNOWN;
+               resolver_unref(resolver_to_check);
+               resolver_to_check->check_tv.tv_sec = RESOLVER_CHECK_SEC;
+               evtimer_add(&resolver_to_check->check_ev,
+                   &resolver_to_check->check_tv);
 
                log_debug("%s: evtimer_add: %lld - %s: %s", __func__,
-                   data->res->check_tv.tv_sec,
-                   uw_resolver_type_str[data->res->type],
-                   uw_resolver_state_str[data->res->state]);
+                   resolver_to_check->check_tv.tv_sec,
+                   uw_resolver_type_str[resolver_to_check->type],
+                   uw_resolver_state_str[resolver_to_check->state]);
        }
 
-       if (!(res->type == UW_RES_DHCP || res->type == UW_RES_FORWARDER))
+       if (!(resolver_to_check->type == UW_RES_DHCP ||
+           resolver_to_check->type == UW_RES_FORWARDER))
                return;
 
        log_debug("%s: create_resolver for oppdot", __func__);
-       if ((check_res = create_resolver(res->type, 1)) == NULL)
-               fatal("%s", __func__);
-       if ((data = malloc(sizeof(*data))) == NULL)
+       if ((res = create_resolver(resolver_to_check->type, 1)) == NULL)
                fatal("%s", __func__);
 
-       resolver_ref(check_res);
-       resolver_ref(res);
-       data->check_res = check_res;
-       data->res = res;
+       resolver_ref(resolver_to_check);
 
-       if (resolve(check_res, ".", LDNS_RR_TYPE_NS, LDNS_RR_CLASS_IN,
-           data, check_resolver_done) != 0) {
+       if (resolve(res, ".", LDNS_RR_TYPE_NS, LDNS_RR_CLASS_IN,
+           resolver_to_check, check_resolver_done) != 0) {
                log_debug("check oppdot failed");
                /* do not overwrite normal DNS state, it might work */
-               resolver_unref(check_res);
-               resolver_unref(res);
+               resolver_unref(resolver_to_check);
 
-               res->check_tv.tv_sec = RESOLVER_CHECK_SEC;
-               evtimer_add(&res->check_ev, &res->check_tv);
+               resolver_to_check->check_tv.tv_sec = RESOLVER_CHECK_SEC;
+               evtimer_add(&resolver_to_check->check_ev,
+                   &resolver_to_check->check_tv);
 
                log_debug("%s: evtimer_add: %lld - %s: %s", __func__,
-                   data->res->check_tv.tv_sec,
-                   uw_resolver_type_str[data->res->type],
-                   uw_resolver_state_str[data->res->state]);
+                   resolver_to_check->check_tv.tv_sec,
+                   uw_resolver_type_str[resolver_to_check->type],
+                   uw_resolver_state_str[resolver_to_check->state]);
        }
 }
 
 void
-check_resolver_done(void *arg, int rcode, void *answer_packet, int answer_len,
-    int sec, char *why_bogus)
+check_resolver_done(struct uw_resolver *res, void *arg, int rcode,
+    void *answer_packet, int answer_len, int sec, char *why_bogus)
 {
-       struct check_resolver_data      *data;
-       struct timeval                   tv = {0, 1};
-       enum uw_resolver_state           prev_state;
-       char                            *str;
-
-       data = (struct check_resolver_data *)arg;
+       struct uw_resolver      *checked_resolver = arg;
+       struct timeval           tv = {0, 1};
+       enum uw_resolver_state   prev_state;
+       char                    *str;
 
        log_debug("%s: %s rcode: %d", __func__,
-           uw_resolver_type_str[data->res->type], rcode);
+           uw_resolver_type_str[checked_resolver->type], rcode);
 
-       prev_state = data->res->state;
+       prev_state = checked_resolver->state;
 
        if (answer_len < LDNS_HEADER_SIZE) {
-               data->res->state = DEAD;
+               checked_resolver->state = DEAD;
                log_warnx("bad packet: too short");
                goto out;
        }
 
        if (rcode == LDNS_RCODE_SERVFAIL) {
-               if (data->check_res->oppdot == data->res->oppdot) {
-                       data->res->state = DEAD;
-                       if (data->res->oppdot) {
+               if (res->oppdot == checked_resolver->oppdot) {
+                       checked_resolver->state = DEAD;
+                       if (checked_resolver->oppdot) {
                                /* downgrade from opportunistic DoT */
-                               switch (data->res->type) {
+                               switch (checked_resolver->type) {
                                case UW_RES_DHCP:
                                        new_forwarders(0);
                                        break;
@@ -1362,9 +1345,9 @@ check_resolver_done(void *arg, int rcode, void 
*answer_packet, int answer_len,
                goto out;
        }
 
-       if (data->check_res->oppdot && !data->res->oppdot) {
+       if (res->oppdot && !checked_resolver->oppdot) {
                /* upgrade to opportunistic DoT */
-               switch (data->res->type) {
+               switch (checked_resolver->type) {
                case UW_RES_DHCP:
                        new_forwarders(1);
                        break;
@@ -1382,56 +1365,57 @@ check_resolver_done(void *arg, int rcode, void 
*answer_packet, int answer_len,
        }
 
        if (sec == SECURE) {
-               data->res->state = VALIDATING;
+               checked_resolver->state = VALIDATING;
                if (!(evtimer_pending(&trust_anchor_timer, NULL)))
                        evtimer_add(&trust_anchor_timer, &tv);
         } else if (rcode == LDNS_RCODE_NOERROR &&
            LDNS_RCODE_WIRE((uint8_t*)answer_packet) == LDNS_RCODE_NOERROR) {
                log_debug("%s: why bogus: %s", __func__, why_bogus);
-               data->res->state = RESOLVING;
+               checked_resolver->state = RESOLVING;
                /* best effort */
-               data->res->why_bogus = strdup(why_bogus);
+               checked_resolver->why_bogus = strdup(why_bogus);
        } else
-               data->res->state = DEAD; /* we know the root exists */
+               checked_resolver->state = DEAD; /* we know the root exists */
 
 out:
-       if (!data->res->stop && data->res->state == DEAD) {
+       if (!checked_resolver->stop && checked_resolver->state == DEAD) {
                if (prev_state == DEAD)
-                       data->res->check_tv.tv_sec *= 2;
+                       checked_resolver->check_tv.tv_sec *= 2;
                else
-                       data->res->check_tv.tv_sec = RESOLVER_CHECK_SEC;
+                       checked_resolver->check_tv.tv_sec = RESOLVER_CHECK_SEC;
 
-               if (data->res->check_tv.tv_sec > RESOLVER_CHECK_MAXSEC)
-                       data->res->check_tv.tv_sec = RESOLVER_CHECK_MAXSEC;
+               if (checked_resolver->check_tv.tv_sec > RESOLVER_CHECK_MAXSEC)
+                       checked_resolver->check_tv.tv_sec =
+                           RESOLVER_CHECK_MAXSEC;
 
-               evtimer_add(&data->res->check_ev, &data->res->check_tv);
+               evtimer_add(&checked_resolver->check_ev,
+                   &checked_resolver->check_tv);
 
                log_debug("%s: evtimer_add: %lld - %s: %s", __func__,
-                   data->res->check_tv.tv_sec,
-                   uw_resolver_type_str[data->res->type],
-                   uw_resolver_state_str[data->res->state]);
+                   checked_resolver->check_tv.tv_sec,
+                   uw_resolver_type_str[checked_resolver->type],
+                   uw_resolver_state_str[checked_resolver->state]);
        }
 
        log_debug("%s: %s: %s", __func__,
-           uw_resolver_type_str[data->res->type],
-           uw_resolver_state_str[data->res->state]);
-
-       log_debug("%s: %p - %p", __func__, data->res, data->res->ctx);
+           uw_resolver_type_str[checked_resolver->type],
+           uw_resolver_state_str[checked_resolver->state]);
 
-       resolver_unref(data->res);
-       data->check_res->stop = 1; /* do not free in callback */
-       resolver_unref(data->check_res);
+       log_debug("%s: %p - %p", __func__, checked_resolver,
+           checked_resolver->ctx);
 
-       free(data);
+       resolver_unref(checked_resolver);
+       res->stop = 1; /* do not free in callback */
 }
 
 void
 asr_resolve_done(struct asr_result *ar, void *arg)
 {
        struct resolver_cb_data *cb_data = arg;
-       cb_data->cb(cb_data->data, ar->ar_rcode, ar->ar_data, ar->ar_datalen, 0,
-           "");
+       cb_data->cb(cb_data->res, cb_data->data, ar->ar_rcode, ar->ar_data,
+           ar->ar_datalen, 0, "");
        free(ar->ar_data);
+       resolver_unref(cb_data->res);
        free(cb_data);
 }
 
@@ -1440,8 +1424,9 @@ ub_resolve_done(void *arg, int rcode, void 
*answer_packet, int answer_len,
     int sec, char *why_bogus, int was_ratelimited)
 {
        struct resolver_cb_data *cb_data = arg;
-       cb_data->cb(cb_data->data, rcode, answer_packet, answer_len, sec,
-           why_bogus);
+       cb_data->cb(cb_data->res, cb_data->data, rcode, answer_packet,
+           answer_len, sec, why_bogus);
+       resolver_unref(cb_data->res);
        free(cb_data);
 }
 
@@ -1779,13 +1764,10 @@ trust_anchor_resolve(void)
        if (res == NULL || res->state < VALIDATING)
                goto err;
 
-       resolver_ref(res);
-
-       if (resolve(res, ".",  LDNS_RR_TYPE_DNSKEY, LDNS_RR_CLASS_IN, res,
-           trust_anchor_resolve_done) != 0) {
-               resolver_unref(res);
+       if (resolve(res, ".",  LDNS_RR_TYPE_DNSKEY, LDNS_RR_CLASS_IN, NULL,
+           trust_anchor_resolve_done) != 0)
                goto err;
-       }
+
        return;
  err:
        evtimer_add(&trust_anchor_timer, &tv);
@@ -1798,10 +1780,9 @@ trust_anchor_timo(int fd, short events, void *arg)
 }
 
 void
-trust_anchor_resolve_done(void *arg, int rcode, void *answer_packet,
-    int answer_len, int sec, char *why_bogus)
+trust_anchor_resolve_done(struct uw_resolver *res, void *arg, int rcode,
+    void *answer_packet, int answer_len, int sec, char *why_bogus)
 {
-       struct uw_resolver      *res;
        struct ub_result        *result;
        sldns_buffer            *buf;
        struct regional         *region;
@@ -1810,8 +1791,6 @@ trust_anchor_resolve_done(void *arg, int rcode, void 
*answer_packet,
        uint16_t                 dnskey_flags;
        char                    *str, rdata_buf[1024], *ta;
 
-       res = (struct uw_resolver *)arg;
-
        if ((result = calloc(1, sizeof(*result))) == NULL)
                goto out;
 
@@ -1888,7 +1867,6 @@ trust_anchor_resolve_done(void *arg, int rcode, void 
*answer_packet,
        }
 out:
        ub_resolve_free(result);
-       resolver_unref(res);
        evtimer_add(&trust_anchor_timer, &tv);
 }
 
diff --git unwind.h unwind.h
index 9e762681ac7..c3bafc4445e 100644
--- unwind.h
+++ unwind.h
@@ -153,7 +153,6 @@ struct query_imsg {
        int              c;
        int              err;
        int              bogus;
-       void            *resolver;
        struct timespec  tp;
 };
 
 

-- 
I'm not entirely sure you are real.

Reply via email to