On Tue, Nov 12, 2019 at 05:45:38PM +0100, Florian Obser wrote:

> Did I get this right? I'd appreciate it if someone could give this a
> once-over.
> 
>     Since resolve() switched to a callback mechanism, all uw_resolver
>     objects pass through resolve() and either asr_resolve_done() or
>     ub_resolve_done().
>     With that we can pull resolver_ref() and resolver_unref() into those
>     functions to make the reference counting easier.
>     Only check_resolver is special since it needs to refcount the
>     resolver that is to be checked. But the resolver doing the actual
>     work is automatically refcounted by resolve() and *_resolve_done().
>     One last piece of the puzzle is to track the uw_resolver object in
>     cb_data so that the *_resolve_done() functions have access to it.
>     This also allows us to remove the ad-hoc passing of the resolver in
>     query_imsg. Since the callback functions all need access to the
>     resolver that did the work, we pass it in as the first argument.
> 

Reviewed the code, did some tests and it all looks good.

One nit: I would have declared a typedef for the callback function type,
to be used in struct resolver_cb_data and in the prototype and the
definition of resolve(); it makes those lines easier to read.  But ok
anyway,

        -Otto


> diff --git resolver.c resolver.c
> index d1dce2dec71..2b7d81d29fc 100644
> --- resolver.c
> +++ resolver.c
> @@ -92,14 +92,11 @@ struct uw_resolver {
>       int64_t                  histogram[nitems(histogram_limits)];
>  };
>  
> -struct check_resolver_data {
> -     struct uw_resolver      *res;
> -     struct uw_resolver      *check_res;
> -};
> -
>  struct resolver_cb_data {
> -     void    (*cb)(void *, int, void *, int, int, char *);
> -     void    *data;
> +     void                    (*cb)(struct uw_resolver *, void *, int, void *,
> +                                 int, int, char *);
> +     void                    *data;
> +     struct uw_resolver      *res;
>  };
>  
>  __dead void           resolver_shutdown(void);
> @@ -108,9 +105,10 @@ void                      resolver_dispatch_frontend(int, short, void *);
>  void                  resolver_dispatch_captiveportal(int, short, void *);
>  void                  resolver_dispatch_main(int, short, void *);
>  int                   resolve(struct uw_resolver *, const char*, int, int,
> -                          void*, void (*cb)(void *, int, void *, int, int,
> -                          char *));
> -void                  resolve_done(void *, int, void *, int, int, char *);
> +                          void*, void (*cb)(struct uw_resolver *, void *,
> +                          int, void *, int, int, char *));
> +void                  resolve_done(struct uw_resolver *, void *, int, void *,
> +                          int, int, char *);
>  void                  ub_resolve_done(void *, int, void *, int, int, char *,
>                            int);
>  void                  asr_resolve_done(struct asr_result *, void *);
> @@ -129,8 +127,8 @@ void                       set_forwarders_oppdot(struct uw_resolver *,
>  void                  resolver_check_timo(int, short, void *);
>  void                  resolver_free_timo(int, short, void *);
>  void                  check_resolver(struct uw_resolver *);
> -void                  check_resolver_done(void *, int, void *, int, int,
> -                          char *);
> +void                  check_resolver_done(struct uw_resolver *, void *, int,
> +                          void *, int, int, char *);
>  void                  schedule_recheck_all_resolvers(void);
>  int                   check_forwarders_changed(struct uw_forwarder_head *,
>                            struct uw_forwarder_head *);
> @@ -154,8 +152,8 @@ int                        check_captive_portal_changed(struct uw_conf *,
>                            struct uw_conf *);
>  void                  trust_anchor_resolve(void);
>  void                  trust_anchor_timo(int, short, void *);
> -void                  trust_anchor_resolve_done(void *, int, void *, int,
> -                          int, char *);
> +void                  trust_anchor_resolve_done(struct uw_resolver *, void *,
> +                          int, void *, int, int, char *);
>  void                  add_autoconf_forwarders(struct imsg_rdns_proposal *);
>  void                  rem_autoconf_forwarders(struct imsg_rdns_proposal *);
>  struct uw_forwarder  *find_forwarder(struct uw_forwarder_head *,
> @@ -480,14 +478,10 @@ resolver_dispatch_frontend(int fd, short event, void *bula)
>                       log_debug("%s: choosing %s", __func__,
>                           uw_resolver_type_str[res->type]);
>  
> -                     query_imsg->resolver = res;
> -                     resolver_ref(res);
> -
>                       clock_gettime(CLOCK_MONOTONIC, &query_imsg->tp);
>  
> -                     if ((resolve(res, query_imsg->qname, query_imsg->t,
> -                         query_imsg->c, query_imsg, resolve_done)) != 0)
> -                             resolver_unref(res);
> +                     resolve(res, query_imsg->qname, query_imsg->t,
> +                         query_imsg->c, query_imsg, resolve_done);
>                       break;
>               case IMSG_FORWARDER:
>                       /* make sure this is a string */
> @@ -773,16 +767,20 @@ resolver_dispatch_main(int fd, short event, void *bula)
>  
>  int
>  resolve(struct uw_resolver *res, const char* name, int rrtype, int rrclass,
> -    void *mydata, void (*cb)(void *, int, void *, int, int, char *))
> +    void *mydata, void (*cb)(struct uw_resolver *, void *, int, void *, int,
> +    int, char *))
>  {
>       struct resolver_cb_data *cb_data = NULL;
>       struct asr_query        *aq = NULL;
>       int                      err;
>  
> +     resolver_ref(res);
> +
>       if ((cb_data = malloc(sizeof(*cb_data))) == NULL)
>               goto err;
>       cb_data->cb = cb;
>       cb_data->data = mydata;
> +     cb_data->res = res;
>  
>       switch(res->type) {
>       case UW_RES_ASR:
> @@ -816,15 +814,15 @@ resolve(struct uw_resolver *res, const char* name, int rrtype, int rrclass,
>   err:
>       free(cb_data);
>       free(aq);
> +     resolver_unref(res);
>       return 1;
>  }
>  
>  void
> -resolve_done(void *arg, int rcode, void *answer_packet, int answer_len,
> -    int sec, char *why_bogus)
> +resolve_done(struct uw_resolver *res, void *arg, int rcode,
> +    void *answer_packet, int answer_len, int sec, char *why_bogus)
>  {
>       struct query_imsg       *query_imsg;
> -     struct uw_resolver      *res;
>       struct timespec          tp, elapsed;
>       int64_t                  ms;
>       size_t                   i;
> @@ -833,7 +831,6 @@ resolve_done(void *arg, int rcode, void *answer_packet, int answer_len,
>       clock_gettime(CLOCK_MONOTONIC, &tp);
>  
>       query_imsg = (struct query_imsg *)arg;
> -     res = (struct uw_resolver *)query_imsg->resolver;
>  
>       timespecsub(&tp, &query_imsg->tp, &elapsed);
>  
> @@ -886,7 +883,6 @@ resolve_done(void *arg, int rcode, void *answer_packet, int answer_len,
>           answer_packet, answer_len);
>  
>       free(query_imsg);
> -     resolver_unref(res);
>       return;
>  
>  servfail:
> @@ -894,7 +890,6 @@ servfail:
>       resolver_imsg_compose_frontend(IMSG_ANSWER_HEADER, 0, query_imsg,
>           sizeof(*query_imsg));
>       free(query_imsg);
> -     resolver_unref(res);
>  }
>  
>  void
> @@ -1259,95 +1254,83 @@ resolver_free_timo(int fd, short events, void *arg)
>  }
>  
>  void
> -check_resolver(struct uw_resolver *res)
> +check_resolver(struct uw_resolver *resolver_to_check)
>  {
> -     struct uw_resolver              *check_res;
> -     struct check_resolver_data      *data;
> +     struct uw_resolver              *res;
>  
>       log_debug("%s: create_resolver", __func__);
> -     if ((check_res = create_resolver(res->type, 0)) == NULL)
> -             fatal("%s", __func__);
> -     if ((data = malloc(sizeof(*data))) == NULL)
> +     if ((res = create_resolver(resolver_to_check->type, 0)) == NULL)
>               fatal("%s", __func__);
>  
> -     resolver_ref(check_res);
> -     resolver_ref(res);
> -     data->check_res = check_res;
> -     data->res = res;
> +     resolver_ref(resolver_to_check);
>  
> -     if (resolve(check_res, ".", LDNS_RR_TYPE_NS, LDNS_RR_CLASS_IN,
> -         data, check_resolver_done) != 0) {
> -             res->state = UNKNOWN;
> -             resolver_unref(check_res);
> -             resolver_unref(res);
> -             res->check_tv.tv_sec = RESOLVER_CHECK_SEC;
> -             evtimer_add(&res->check_ev, &res->check_tv);
> +     if (resolve(res, ".", LDNS_RR_TYPE_NS, LDNS_RR_CLASS_IN,
> +         resolver_to_check, check_resolver_done) != 0) {
> +             resolver_to_check->state = UNKNOWN;
> +             resolver_unref(resolver_to_check);
> +             resolver_to_check->check_tv.tv_sec = RESOLVER_CHECK_SEC;
> +             evtimer_add(&resolver_to_check->check_ev,
> +                 &resolver_to_check->check_tv);
>  
>               log_debug("%s: evtimer_add: %lld - %s: %s", __func__,
> -                 data->res->check_tv.tv_sec,
> -                 uw_resolver_type_str[data->res->type],
> -                 uw_resolver_state_str[data->res->state]);
> +                 resolver_to_check->check_tv.tv_sec,
> +                 uw_resolver_type_str[resolver_to_check->type],
> +                 uw_resolver_state_str[resolver_to_check->state]);
>       }
>  
> -     if (!(res->type == UW_RES_DHCP || res->type == UW_RES_FORWARDER))
> +     if (!(resolver_to_check->type == UW_RES_DHCP ||
> +         resolver_to_check->type == UW_RES_FORWARDER))
>               return;
>  
>       log_debug("%s: create_resolver for oppdot", __func__);
> -     if ((check_res = create_resolver(res->type, 1)) == NULL)
> -             fatal("%s", __func__);
> -     if ((data = malloc(sizeof(*data))) == NULL)
> +     if ((res = create_resolver(resolver_to_check->type, 1)) == NULL)
>               fatal("%s", __func__);
>  
> -     resolver_ref(check_res);
> -     resolver_ref(res);
> -     data->check_res = check_res;
> -     data->res = res;
> +     resolver_ref(resolver_to_check);
>  
> -     if (resolve(check_res, ".", LDNS_RR_TYPE_NS, LDNS_RR_CLASS_IN,
> -         data, check_resolver_done) != 0) {
> +     if (resolve(res, ".", LDNS_RR_TYPE_NS, LDNS_RR_CLASS_IN,
> +         resolver_to_check, check_resolver_done) != 0) {
>               log_debug("check oppdot failed");
>               /* do not overwrite normal DNS state, it might work */
> -             resolver_unref(check_res);
> -             resolver_unref(res);
> +             resolver_unref(resolver_to_check);
>  
> -             res->check_tv.tv_sec = RESOLVER_CHECK_SEC;
> -             evtimer_add(&res->check_ev, &res->check_tv);
> +             resolver_to_check->check_tv.tv_sec = RESOLVER_CHECK_SEC;
> +             evtimer_add(&resolver_to_check->check_ev,
> +                 &resolver_to_check->check_tv);
>  
>               log_debug("%s: evtimer_add: %lld - %s: %s", __func__,
> -                 data->res->check_tv.tv_sec,
> -                 uw_resolver_type_str[data->res->type],
> -                 uw_resolver_state_str[data->res->state]);
> +                 resolver_to_check->check_tv.tv_sec,
> +                 uw_resolver_type_str[resolver_to_check->type],
> +                 uw_resolver_state_str[resolver_to_check->state]);
>       }
>  }
>  
>  void
> -check_resolver_done(void *arg, int rcode, void *answer_packet, int answer_len,
> -    int sec, char *why_bogus)
> +check_resolver_done(struct uw_resolver *res, void *arg, int rcode,
> +    void *answer_packet, int answer_len, int sec, char *why_bogus)
>  {
> -     struct check_resolver_data      *data;
> -     struct timeval                   tv = {0, 1};
> -     enum uw_resolver_state           prev_state;
> -     char                            *str;
> -
> -     data = (struct check_resolver_data *)arg;
> +     struct uw_resolver      *checked_resolver = arg;
> +     struct timeval           tv = {0, 1};
> +     enum uw_resolver_state   prev_state;
> +     char                    *str;
>  
>       log_debug("%s: %s rcode: %d", __func__,
> -         uw_resolver_type_str[data->res->type], rcode);
> +         uw_resolver_type_str[checked_resolver->type], rcode);
>  
> -     prev_state = data->res->state;
> +     prev_state = checked_resolver->state;
>  
>       if (answer_len < LDNS_HEADER_SIZE) {
> -             data->res->state = DEAD;
> +             checked_resolver->state = DEAD;
>               log_warnx("bad packet: too short");
>               goto out;
>       }
>  
>       if (rcode == LDNS_RCODE_SERVFAIL) {
> -             if (data->check_res->oppdot == data->res->oppdot) {
> -                     data->res->state = DEAD;
> -                     if (data->res->oppdot) {
> +             if (res->oppdot == checked_resolver->oppdot) {
> +                     checked_resolver->state = DEAD;
> +                     if (checked_resolver->oppdot) {
>                               /* downgrade from opportunistic DoT */
> -                             switch (data->res->type) {
> +                             switch (checked_resolver->type) {
>                               case UW_RES_DHCP:
>                                       new_forwarders(0);
>                                       break;
> @@ -1362,9 +1345,9 @@ check_resolver_done(void *arg, int rcode, void *answer_packet, int answer_len,
>               goto out;
>       }
>  
> -     if (data->check_res->oppdot && !data->res->oppdot) {
> +     if (res->oppdot && !checked_resolver->oppdot) {
>               /* upgrade to opportunistic DoT */
> -             switch (data->res->type) {
> +             switch (checked_resolver->type) {
>               case UW_RES_DHCP:
>                       new_forwarders(1);
>                       break;
> @@ -1382,56 +1365,57 @@ check_resolver_done(void *arg, int rcode, void *answer_packet, int answer_len,
>       }
>  
>       if (sec == SECURE) {
> -             data->res->state = VALIDATING;
> +             checked_resolver->state = VALIDATING;
>               if (!(evtimer_pending(&trust_anchor_timer, NULL)))
>                       evtimer_add(&trust_anchor_timer, &tv);
>        } else if (rcode == LDNS_RCODE_NOERROR &&
>           LDNS_RCODE_WIRE((uint8_t*)answer_packet) == LDNS_RCODE_NOERROR) {
>               log_debug("%s: why bogus: %s", __func__, why_bogus);
> -             data->res->state = RESOLVING;
> +             checked_resolver->state = RESOLVING;
>               /* best effort */
> -             data->res->why_bogus = strdup(why_bogus);
> +             checked_resolver->why_bogus = strdup(why_bogus);
>       } else
> -             data->res->state = DEAD; /* we know the root exists */
> +             checked_resolver->state = DEAD; /* we know the root exists */
>  
>  out:
> -     if (!data->res->stop && data->res->state == DEAD) {
> +     if (!checked_resolver->stop && checked_resolver->state == DEAD) {
>               if (prev_state == DEAD)
> -                     data->res->check_tv.tv_sec *= 2;
> +                     checked_resolver->check_tv.tv_sec *= 2;
>               else
> -                     data->res->check_tv.tv_sec = RESOLVER_CHECK_SEC;
> +                     checked_resolver->check_tv.tv_sec = RESOLVER_CHECK_SEC;
>  
> -             if (data->res->check_tv.tv_sec > RESOLVER_CHECK_MAXSEC)
> -                     data->res->check_tv.tv_sec = RESOLVER_CHECK_MAXSEC;
> +             if (checked_resolver->check_tv.tv_sec > RESOLVER_CHECK_MAXSEC)
> +                     checked_resolver->check_tv.tv_sec =
> +                         RESOLVER_CHECK_MAXSEC;
>  
> -             evtimer_add(&data->res->check_ev, &data->res->check_tv);
> +             evtimer_add(&checked_resolver->check_ev,
> +                 &checked_resolver->check_tv);
>  
>               log_debug("%s: evtimer_add: %lld - %s: %s", __func__,
> -                 data->res->check_tv.tv_sec,
> -                 uw_resolver_type_str[data->res->type],
> -                 uw_resolver_state_str[data->res->state]);
> +                 checked_resolver->check_tv.tv_sec,
> +                 uw_resolver_type_str[checked_resolver->type],
> +                 uw_resolver_state_str[checked_resolver->state]);
>       }
>  
>       log_debug("%s: %s: %s", __func__,
> -         uw_resolver_type_str[data->res->type],
> -         uw_resolver_state_str[data->res->state]);
> -
> -     log_debug("%s: %p - %p", __func__, data->res, data->res->ctx);
> +         uw_resolver_type_str[checked_resolver->type],
> +         uw_resolver_state_str[checked_resolver->state]);
>  
> -     resolver_unref(data->res);
> -     data->check_res->stop = 1; /* do not free in callback */
> -     resolver_unref(data->check_res);
> +     log_debug("%s: %p - %p", __func__, checked_resolver,
> +         checked_resolver->ctx);
>  
> -     free(data);
> +     resolver_unref(checked_resolver);
> +     res->stop = 1; /* do not free in callback */
>  }
>  
>  void
>  asr_resolve_done(struct asr_result *ar, void *arg)
>  {
>       struct resolver_cb_data *cb_data = arg;
> -     cb_data->cb(cb_data->data, ar->ar_rcode, ar->ar_data, ar->ar_datalen, 0,
> -         "");
> +     cb_data->cb(cb_data->res, cb_data->data, ar->ar_rcode, ar->ar_data,
> +         ar->ar_datalen, 0, "");
>       free(ar->ar_data);
> +     resolver_unref(cb_data->res);
>       free(cb_data);
>  }
>  
> @@ -1440,8 +1424,9 @@ ub_resolve_done(void *arg, int rcode, void *answer_packet, int answer_len,
>      int sec, char *why_bogus, int was_ratelimited)
>  {
>       struct resolver_cb_data *cb_data = arg;
> -     cb_data->cb(cb_data->data, rcode, answer_packet, answer_len, sec,
> -         why_bogus);
> +     cb_data->cb(cb_data->res, cb_data->data, rcode, answer_packet,
> +         answer_len, sec, why_bogus);
> +     resolver_unref(cb_data->res);
>       free(cb_data);
>  }
>  
> @@ -1779,13 +1764,10 @@ trust_anchor_resolve(void)
>       if (res == NULL || res->state < VALIDATING)
>               goto err;
>  
> -     resolver_ref(res);
> -
> -     if (resolve(res, ".",  LDNS_RR_TYPE_DNSKEY, LDNS_RR_CLASS_IN, res,
> -         trust_anchor_resolve_done) != 0) {
> -             resolver_unref(res);
> +     if (resolve(res, ".",  LDNS_RR_TYPE_DNSKEY, LDNS_RR_CLASS_IN, NULL,
> +         trust_anchor_resolve_done) != 0)
>               goto err;
> -     }
> +
>       return;
>   err:
>       evtimer_add(&trust_anchor_timer, &tv);
> @@ -1798,10 +1780,9 @@ trust_anchor_timo(int fd, short events, void *arg)
>  }
>  
>  void
> -trust_anchor_resolve_done(void *arg, int rcode, void *answer_packet,
> -    int answer_len, int sec, char *why_bogus)
> +trust_anchor_resolve_done(struct uw_resolver *res, void *arg, int rcode,
> +    void *answer_packet, int answer_len, int sec, char *why_bogus)
>  {
> -     struct uw_resolver      *res;
>       struct ub_result        *result;
>       sldns_buffer            *buf;
>       struct regional         *region;
> @@ -1810,8 +1791,6 @@ trust_anchor_resolve_done(void *arg, int rcode, void *answer_packet,
>       uint16_t                 dnskey_flags;
>       char                    *str, rdata_buf[1024], *ta;
>  
> -     res = (struct uw_resolver *)arg;
> -
>       if ((result = calloc(1, sizeof(*result))) == NULL)
>               goto out;
>  
> @@ -1888,7 +1867,6 @@ trust_anchor_resolve_done(void *arg, int rcode, void *answer_packet,
>       }
>  out:
>       ub_resolve_free(result);
> -     resolver_unref(res);
>       evtimer_add(&trust_anchor_timer, &tv);
>  }
>  
> diff --git unwind.h unwind.h
> index 9e762681ac7..c3bafc4445e 100644
> --- unwind.h
> +++ unwind.h
> @@ -153,7 +153,6 @@ struct query_imsg {
>       int              c;
>       int              err;
>       int              bogus;
> -     void            *resolver;
>       struct timespec  tp;
>  };
>  
>  
> 
> -- 
> I'm not entirely sure you are real.
> 

Reply via email to