On Tue, Nov 12, 2019 at 05:45:38PM +0100, Florian Obser wrote: > Did I get this right? I'd appreciate it if someone could give this a > once over. > > Since resolve() switched to a callback mechanism, all uw_resolver objects > pass through resolve() and either asr_resolve_done() or > ub_resolve_done(). > With that we can pull resolver_ref() and resolver_unref() into those > functions to make the reference counting easier. > Only check_resolver is special since it needs to refcount the to be > checked resolver. But the resolver doing the actual work is > automatically refcounted by resolve() and *_resolve_done(). > One last piece of the puzzle is to track the uw_resolver object in > cb_data so that the *_resolve_done() functions have access to it. > This also allows us to remove the ad-hoc passing of the resolver in > query_imsg. Since the callback functions all need access to the > resolver that did the work, we pass it in as the first argument. >
Reviewed the code, did some tests and it all looks good. One nit: I would have declared a typedef for the callback function type to be used in the struct resolver_cb_data and the prototype and the definition of resolve(); it makes those lines easier to read. But ok anyway, -Otto > diff --git resolver.c resolver.c > index d1dce2dec71..2b7d81d29fc 100644 > --- resolver.c > +++ resolver.c > @@ -92,14 +92,11 @@ struct uw_resolver { > int64_t histogram[nitems(histogram_limits)]; > }; > > -struct check_resolver_data { > - struct uw_resolver *res; > - struct uw_resolver *check_res; > -}; > - > struct resolver_cb_data { > - void (*cb)(void *, int, void *, int, int, char *); > - void *data; > + void (*cb)(struct uw_resolver *, void *, int, void *, > + int, int, char *); > + void *data; > + struct uw_resolver *res; > }; > > __dead void resolver_shutdown(void); > @@ -108,9 +105,10 @@ void > resolver_dispatch_frontend(int, short, void *); > void resolver_dispatch_captiveportal(int, short, void *); > void resolver_dispatch_main(int, short, void *); > int resolve(struct uw_resolver *, const char*, int, int, > - void*, void (*cb)(void *, int, void *, int, int, > - char *)); > -void resolve_done(void *, int, void *, int, int, char *); > + void*, void (*cb)(struct uw_resolver *, void *, > + int, void *, int, int, char *)); > +void resolve_done(struct uw_resolver *, void *, int, void *, > + int, int, char *); > void ub_resolve_done(void *, int, void *, int, int, char *, > int); > void asr_resolve_done(struct asr_result *, void *); > @@ -129,8 +127,8 @@ void set_forwarders_oppdot(struct > uw_resolver *, > void resolver_check_timo(int, short, void *); > void resolver_free_timo(int, short, void *); > void check_resolver(struct uw_resolver *); > -void check_resolver_done(void *, int, void *, int, int, > - char *); > +void check_resolver_done(struct uw_resolver *, void *, int, > + void *, int, int, char *); > void schedule_recheck_all_resolvers(void); > int check_forwarders_changed(struct 
uw_forwarder_head *, > struct uw_forwarder_head *); > @@ -154,8 +152,8 @@ int > check_captive_portal_changed(struct uw_conf *, > struct uw_conf *); > void trust_anchor_resolve(void); > void trust_anchor_timo(int, short, void *); > -void trust_anchor_resolve_done(void *, int, void *, int, > - int, char *); > +void trust_anchor_resolve_done(struct uw_resolver *, void *, > + int, void *, int, int, char *); > void add_autoconf_forwarders(struct imsg_rdns_proposal *); > void rem_autoconf_forwarders(struct imsg_rdns_proposal *); > struct uw_forwarder *find_forwarder(struct uw_forwarder_head *, > @@ -480,14 +478,10 @@ resolver_dispatch_frontend(int fd, short event, void > *bula) > log_debug("%s: choosing %s", __func__, > uw_resolver_type_str[res->type]); > > - query_imsg->resolver = res; > - resolver_ref(res); > - > clock_gettime(CLOCK_MONOTONIC, &query_imsg->tp); > > - if ((resolve(res, query_imsg->qname, query_imsg->t, > - query_imsg->c, query_imsg, resolve_done)) != 0) > - resolver_unref(res); > + resolve(res, query_imsg->qname, query_imsg->t, > + query_imsg->c, query_imsg, resolve_done); > break; > case IMSG_FORWARDER: > /* make sure this is a string */ > @@ -773,16 +767,20 @@ resolver_dispatch_main(int fd, short event, void *bula) > > int > resolve(struct uw_resolver *res, const char* name, int rrtype, int rrclass, > - void *mydata, void (*cb)(void *, int, void *, int, int, char *)) > + void *mydata, void (*cb)(struct uw_resolver *, void *, int, void *, int, > + int, char *)) > { > struct resolver_cb_data *cb_data = NULL; > struct asr_query *aq = NULL; > int err; > > + resolver_ref(res); > + > if ((cb_data = malloc(sizeof(*cb_data))) == NULL) > goto err; > cb_data->cb = cb; > cb_data->data = mydata; > + cb_data->res = res; > > switch(res->type) { > case UW_RES_ASR: > @@ -816,15 +814,15 @@ resolve(struct uw_resolver *res, const char* name, int > rrtype, int rrclass, > err: > free(cb_data); > free(aq); > + resolver_unref(res); > return 1; > } > > void > 
-resolve_done(void *arg, int rcode, void *answer_packet, int answer_len, > - int sec, char *why_bogus) > +resolve_done(struct uw_resolver *res, void *arg, int rcode, > + void *answer_packet, int answer_len, int sec, char *why_bogus) > { > struct query_imsg *query_imsg; > - struct uw_resolver *res; > struct timespec tp, elapsed; > int64_t ms; > size_t i; > @@ -833,7 +831,6 @@ resolve_done(void *arg, int rcode, void *answer_packet, > int answer_len, > clock_gettime(CLOCK_MONOTONIC, &tp); > > query_imsg = (struct query_imsg *)arg; > - res = (struct uw_resolver *)query_imsg->resolver; > > timespecsub(&tp, &query_imsg->tp, &elapsed); > > @@ -886,7 +883,6 @@ resolve_done(void *arg, int rcode, void *answer_packet, > int answer_len, > answer_packet, answer_len); > > free(query_imsg); > - resolver_unref(res); > return; > > servfail: > @@ -894,7 +890,6 @@ servfail: > resolver_imsg_compose_frontend(IMSG_ANSWER_HEADER, 0, query_imsg, > sizeof(*query_imsg)); > free(query_imsg); > - resolver_unref(res); > } > > void > @@ -1259,95 +1254,83 @@ resolver_free_timo(int fd, short events, void *arg) > } > > void > -check_resolver(struct uw_resolver *res) > +check_resolver(struct uw_resolver *resolver_to_check) > { > - struct uw_resolver *check_res; > - struct check_resolver_data *data; > + struct uw_resolver *res; > > log_debug("%s: create_resolver", __func__); > - if ((check_res = create_resolver(res->type, 0)) == NULL) > - fatal("%s", __func__); > - if ((data = malloc(sizeof(*data))) == NULL) > + if ((res = create_resolver(resolver_to_check->type, 0)) == NULL) > fatal("%s", __func__); > > - resolver_ref(check_res); > - resolver_ref(res); > - data->check_res = check_res; > - data->res = res; > + resolver_ref(resolver_to_check); > > - if (resolve(check_res, ".", LDNS_RR_TYPE_NS, LDNS_RR_CLASS_IN, > - data, check_resolver_done) != 0) { > - res->state = UNKNOWN; > - resolver_unref(check_res); > - resolver_unref(res); > - res->check_tv.tv_sec = RESOLVER_CHECK_SEC; > - 
evtimer_add(&res->check_ev, &res->check_tv); > + if (resolve(res, ".", LDNS_RR_TYPE_NS, LDNS_RR_CLASS_IN, > + resolver_to_check, check_resolver_done) != 0) { > + resolver_to_check->state = UNKNOWN; > + resolver_unref(resolver_to_check); > + resolver_to_check->check_tv.tv_sec = RESOLVER_CHECK_SEC; > + evtimer_add(&resolver_to_check->check_ev, > + &resolver_to_check->check_tv); > > log_debug("%s: evtimer_add: %lld - %s: %s", __func__, > - data->res->check_tv.tv_sec, > - uw_resolver_type_str[data->res->type], > - uw_resolver_state_str[data->res->state]); > + resolver_to_check->check_tv.tv_sec, > + uw_resolver_type_str[resolver_to_check->type], > + uw_resolver_state_str[resolver_to_check->state]); > } > > - if (!(res->type == UW_RES_DHCP || res->type == UW_RES_FORWARDER)) > + if (!(resolver_to_check->type == UW_RES_DHCP || > + resolver_to_check->type == UW_RES_FORWARDER)) > return; > > log_debug("%s: create_resolver for oppdot", __func__); > - if ((check_res = create_resolver(res->type, 1)) == NULL) > - fatal("%s", __func__); > - if ((data = malloc(sizeof(*data))) == NULL) > + if ((res = create_resolver(resolver_to_check->type, 1)) == NULL) > fatal("%s", __func__); > > - resolver_ref(check_res); > - resolver_ref(res); > - data->check_res = check_res; > - data->res = res; > + resolver_ref(resolver_to_check); > > - if (resolve(check_res, ".", LDNS_RR_TYPE_NS, LDNS_RR_CLASS_IN, > - data, check_resolver_done) != 0) { > + if (resolve(res, ".", LDNS_RR_TYPE_NS, LDNS_RR_CLASS_IN, > + resolver_to_check, check_resolver_done) != 0) { > log_debug("check oppdot failed"); > /* do not overwrite normal DNS state, it might work */ > - resolver_unref(check_res); > - resolver_unref(res); > + resolver_unref(resolver_to_check); > > - res->check_tv.tv_sec = RESOLVER_CHECK_SEC; > - evtimer_add(&res->check_ev, &res->check_tv); > + resolver_to_check->check_tv.tv_sec = RESOLVER_CHECK_SEC; > + evtimer_add(&resolver_to_check->check_ev, > + &resolver_to_check->check_tv); > > log_debug("%s: 
evtimer_add: %lld - %s: %s", __func__, > - data->res->check_tv.tv_sec, > - uw_resolver_type_str[data->res->type], > - uw_resolver_state_str[data->res->state]); > + resolver_to_check->check_tv.tv_sec, > + uw_resolver_type_str[resolver_to_check->type], > + uw_resolver_state_str[resolver_to_check->state]); > } > } > > void > -check_resolver_done(void *arg, int rcode, void *answer_packet, int > answer_len, > - int sec, char *why_bogus) > +check_resolver_done(struct uw_resolver *res, void *arg, int rcode, > + void *answer_packet, int answer_len, int sec, char *why_bogus) > { > - struct check_resolver_data *data; > - struct timeval tv = {0, 1}; > - enum uw_resolver_state prev_state; > - char *str; > - > - data = (struct check_resolver_data *)arg; > + struct uw_resolver *checked_resolver = arg; > + struct timeval tv = {0, 1}; > + enum uw_resolver_state prev_state; > + char *str; > > log_debug("%s: %s rcode: %d", __func__, > - uw_resolver_type_str[data->res->type], rcode); > + uw_resolver_type_str[checked_resolver->type], rcode); > > - prev_state = data->res->state; > + prev_state = checked_resolver->state; > > if (answer_len < LDNS_HEADER_SIZE) { > - data->res->state = DEAD; > + checked_resolver->state = DEAD; > log_warnx("bad packet: too short"); > goto out; > } > > if (rcode == LDNS_RCODE_SERVFAIL) { > - if (data->check_res->oppdot == data->res->oppdot) { > - data->res->state = DEAD; > - if (data->res->oppdot) { > + if (res->oppdot == checked_resolver->oppdot) { > + checked_resolver->state = DEAD; > + if (checked_resolver->oppdot) { > /* downgrade from opportunistic DoT */ > - switch (data->res->type) { > + switch (checked_resolver->type) { > case UW_RES_DHCP: > new_forwarders(0); > break; > @@ -1362,9 +1345,9 @@ check_resolver_done(void *arg, int rcode, void > *answer_packet, int answer_len, > goto out; > } > > - if (data->check_res->oppdot && !data->res->oppdot) { > + if (res->oppdot && !checked_resolver->oppdot) { > /* upgrade to opportunistic DoT */ > - switch 
(data->res->type) { > + switch (checked_resolver->type) { > case UW_RES_DHCP: > new_forwarders(1); > break; > @@ -1382,56 +1365,57 @@ check_resolver_done(void *arg, int rcode, void > *answer_packet, int answer_len, > } > > if (sec == SECURE) { > - data->res->state = VALIDATING; > + checked_resolver->state = VALIDATING; > if (!(evtimer_pending(&trust_anchor_timer, NULL))) > evtimer_add(&trust_anchor_timer, &tv); > } else if (rcode == LDNS_RCODE_NOERROR && > LDNS_RCODE_WIRE((uint8_t*)answer_packet) == LDNS_RCODE_NOERROR) { > log_debug("%s: why bogus: %s", __func__, why_bogus); > - data->res->state = RESOLVING; > + checked_resolver->state = RESOLVING; > /* best effort */ > - data->res->why_bogus = strdup(why_bogus); > + checked_resolver->why_bogus = strdup(why_bogus); > } else > - data->res->state = DEAD; /* we know the root exists */ > + checked_resolver->state = DEAD; /* we know the root exists */ > > out: > - if (!data->res->stop && data->res->state == DEAD) { > + if (!checked_resolver->stop && checked_resolver->state == DEAD) { > if (prev_state == DEAD) > - data->res->check_tv.tv_sec *= 2; > + checked_resolver->check_tv.tv_sec *= 2; > else > - data->res->check_tv.tv_sec = RESOLVER_CHECK_SEC; > + checked_resolver->check_tv.tv_sec = RESOLVER_CHECK_SEC; > > - if (data->res->check_tv.tv_sec > RESOLVER_CHECK_MAXSEC) > - data->res->check_tv.tv_sec = RESOLVER_CHECK_MAXSEC; > + if (checked_resolver->check_tv.tv_sec > RESOLVER_CHECK_MAXSEC) > + checked_resolver->check_tv.tv_sec = > + RESOLVER_CHECK_MAXSEC; > > - evtimer_add(&data->res->check_ev, &data->res->check_tv); > + evtimer_add(&checked_resolver->check_ev, > + &checked_resolver->check_tv); > > log_debug("%s: evtimer_add: %lld - %s: %s", __func__, > - data->res->check_tv.tv_sec, > - uw_resolver_type_str[data->res->type], > - uw_resolver_state_str[data->res->state]); > + checked_resolver->check_tv.tv_sec, > + uw_resolver_type_str[checked_resolver->type], > + uw_resolver_state_str[checked_resolver->state]); > } > > 
log_debug("%s: %s: %s", __func__, > - uw_resolver_type_str[data->res->type], > - uw_resolver_state_str[data->res->state]); > - > - log_debug("%s: %p - %p", __func__, data->res, data->res->ctx); > + uw_resolver_type_str[checked_resolver->type], > + uw_resolver_state_str[checked_resolver->state]); > > - resolver_unref(data->res); > - data->check_res->stop = 1; /* do not free in callback */ > - resolver_unref(data->check_res); > + log_debug("%s: %p - %p", __func__, checked_resolver, > + checked_resolver->ctx); > > - free(data); > + resolver_unref(checked_resolver); > + res->stop = 1; /* do not free in callback */ > } > > void > asr_resolve_done(struct asr_result *ar, void *arg) > { > struct resolver_cb_data *cb_data = arg; > - cb_data->cb(cb_data->data, ar->ar_rcode, ar->ar_data, ar->ar_datalen, 0, > - ""); > + cb_data->cb(cb_data->res, cb_data->data, ar->ar_rcode, ar->ar_data, > + ar->ar_datalen, 0, ""); > free(ar->ar_data); > + resolver_unref(cb_data->res); > free(cb_data); > } > > @@ -1440,8 +1424,9 @@ ub_resolve_done(void *arg, int rcode, void > *answer_packet, int answer_len, > int sec, char *why_bogus, int was_ratelimited) > { > struct resolver_cb_data *cb_data = arg; > - cb_data->cb(cb_data->data, rcode, answer_packet, answer_len, sec, > - why_bogus); > + cb_data->cb(cb_data->res, cb_data->data, rcode, answer_packet, > + answer_len, sec, why_bogus); > + resolver_unref(cb_data->res); > free(cb_data); > } > > @@ -1779,13 +1764,10 @@ trust_anchor_resolve(void) > if (res == NULL || res->state < VALIDATING) > goto err; > > - resolver_ref(res); > - > - if (resolve(res, ".", LDNS_RR_TYPE_DNSKEY, LDNS_RR_CLASS_IN, res, > - trust_anchor_resolve_done) != 0) { > - resolver_unref(res); > + if (resolve(res, ".", LDNS_RR_TYPE_DNSKEY, LDNS_RR_CLASS_IN, NULL, > + trust_anchor_resolve_done) != 0) > goto err; > - } > + > return; > err: > evtimer_add(&trust_anchor_timer, &tv); > @@ -1798,10 +1780,9 @@ trust_anchor_timo(int fd, short events, void *arg) > } > > void > 
-trust_anchor_resolve_done(void *arg, int rcode, void *answer_packet, > - int answer_len, int sec, char *why_bogus) > +trust_anchor_resolve_done(struct uw_resolver *res, void *arg, int rcode, > + void *answer_packet, int answer_len, int sec, char *why_bogus) > { > - struct uw_resolver *res; > struct ub_result *result; > sldns_buffer *buf; > struct regional *region; > @@ -1810,8 +1791,6 @@ trust_anchor_resolve_done(void *arg, int rcode, void > *answer_packet, > uint16_t dnskey_flags; > char *str, rdata_buf[1024], *ta; > > - res = (struct uw_resolver *)arg; > - > if ((result = calloc(1, sizeof(*result))) == NULL) > goto out; > > @@ -1888,7 +1867,6 @@ trust_anchor_resolve_done(void *arg, int rcode, void > *answer_packet, > } > out: > ub_resolve_free(result); > - resolver_unref(res); > evtimer_add(&trust_anchor_timer, &tv); > } > > diff --git unwind.h unwind.h > index 9e762681ac7..c3bafc4445e 100644 > --- unwind.h > +++ unwind.h > @@ -153,7 +153,6 @@ struct query_imsg { > int c; > int err; > int bogus; > - void *resolver; > struct timespec tp; > }; > > > > -- > I'm not entirely sure you are real. >