Hi Jukka,

> diff --git a/src/dnsproxy.c b/src/dnsproxy.c
> index d03b734..a9bab01 100644
> --- a/src/dnsproxy.c
> +++ b/src/dnsproxy.c
> @@ -90,6 +90,7 @@ struct server_data {
>       gboolean enabled;
>       gboolean connected;
>       struct partial_reply *incoming_reply;
> +     GHashTable *cache;
>  };
>  
>  struct request_data {
> @@ -123,6 +124,33 @@ struct listener_data {
>       guint tcp_listener_watch;
>  };
>  
> +struct cache_entry {
> +     GHashTable *cache;

Why do you keep a pointer to the cache hash table inside the cache entry?

> +     unsigned char *key;
> +     int type;
> +     int timeout;
> +     int answers;
> +     unsigned int data_len;
> +     unsigned char *data; /* contains dns header + body */
> +};
> +
> +/*
> + * We limit the cache size to some sane value so that cached data does
> + * not occupy too much memory. Each cached entry occupies on average
> + * about 100 bytes memory (depending on dns name length).
> + * Example: caching www.connman.net uses 93 bytes memory.
> + * The value is the max amount of cached dns responses (count).
> + */
> +#define MAX_CACHE_SIZE 256
> +static int cache_size;
> +
> +/*
> + * We limit how long the cached dns entry stays in the cache.
> + * By default the TTL (time-to-live) of the dns response is used
> + * when setting the cache entry life time. The value is in seconds.
> + */
> +#define MAX_CACHE_TTL (60 * 30)
> +
>  static GSList *server_list = NULL;
>  static GSList *request_list = NULL;
>  static GSList *request_pending_list = NULL;
> @@ -187,6 +215,34 @@ static struct server_data *find_server(const char *interface,
>       return NULL;
>  }
>  
> +static void send_cached_response(int sk, unsigned char *buf, int len,
> +                             const struct sockaddr *to, socklen_t tolen,
> +                             int protocol, int id, int answers)
> +{
> +     struct domain_hdr *hdr;
> +     int err, offset = protocol_offset(protocol);
> +
> +     if (offset < 0)
> +             return;
> +
> +     if (len < 12)
> +             return;
> +
> +     hdr = (void *) (buf + offset);
> +
> +     hdr->id = id;
> +     hdr->qr = 1;
> +     hdr->rcode = 0;
> +     hdr->ancount = htons(answers);
> +     hdr->nscount = 0;
> +     hdr->arcount = 0;
> +
> +     DBG("id 0x%04x answers %d", hdr->id, answers);
> +
> +     err = sendto(sk, buf, len, 0, to, tolen);
> +     if (err < 0)
> +             return;
> +}
>  
>  static void send_response(int sk, unsigned char *buf, int len,
>                               const struct sockaddr *to, socklen_t tolen,
> @@ -325,12 +381,301 @@ static int append_query(unsigned char *buf, unsigned int size,
>       return ptr - buf;
>  }
>  
> +static struct cache_entry *cache_check(struct server_data *server,
> +                                     gpointer request)
> +{
> +     uint16_t type;
> +     unsigned int offset = 0;
> +     unsigned char *dns_body = (unsigned char *)request + 12;
> +     unsigned char *question = &dns_body[offset];
> +     struct cache_entry *entry;
> +
> +     offset = strlen((char *)question) + 1;
> +     type = ntohs(*(uint16_t *)(&dns_body[offset]));

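Are we sure that we do not need unaligned access handling here? The type
field follows the variable-length question name, so dns_body[offset] is not
guaranteed to be 2-byte aligned.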

> +     /* ATM we only cache either A (1) or AAAA (28) requests */
> +     if (type != 1 && type != 28)
> +             return NULL;
> +
> +     entry = g_hash_table_lookup(server->cache, question);
> +     if (entry == NULL || entry->type != type)
> +             return NULL;
> +
> +     return entry;
> +}
> +
> +static gboolean cache_entry_timeout(gpointer user_data)
> +{
> +     struct cache_entry *entry = user_data;
> +
> +     DBG("cache %d key \"%s\"", cache_size - 1, entry->key);
> +
> +     g_hash_table_remove(entry->cache, entry->key);
> +     return FALSE;
> +}
> +
> +static int get_name(int counter, unsigned char *pkt, unsigned char *start,
> +             unsigned char *max, unsigned char *name, int max_name,
> +             int *name_len, unsigned char **end)
> +{
> +     unsigned char *p;
> +
> +     /* Limit recursion to 10 (this means up to 10 labels in domain name) */
> +     if (counter > 10)
> +             return -EINVAL;
> +
> +     p = start;
> +     while (*p) {
> +             if (*p & 0xc0) {
> +                     uint16_t offset = (*p & 0x3F) * 256 + *(p + 1);
> +
> +                     if (offset >= max - pkt)
> +                             return -ENOBUFS;
> +
> +                     if (*end == NULL)
> +                             *end = p + 2;
> +
> +                     return get_name(counter+1, pkt, pkt + offset, max,
> +                                     name, max_name, name_len, end);
> +             } else {
> +                     unsigned label_len = *p;
> +
> +                     if (pkt+label_len > max)
> +                             return -ENOBUFS;
> +
> +                     if (*name_len > max_name)
> +                             return -ENOBUFS;
> +
> +                     /* We compress the result and use pointers */
> +                     name[0] = 0xC0;
> +                     name[1] = 0x0C;
> +                     *name_len = 2;
> +
> +                     p += label_len + 1;
> +
> +                     if (*end == NULL)
> +                             *end = p;
> +
> +                     if (p >= max)
> +                             return -ENOBUFS;
> +             }
> +     }
> +
> +     return 0;
> +}
> +
> +static int parse_rr(unsigned char *buf, unsigned char *start,
> +             unsigned char *max, unsigned char *response,
> +             unsigned int *response_size, int *len,
> +             int *type, int *class, int *ttl,
> +             unsigned char **end)
> +{
> +     int err, _rdlen;
> +     int name_len = 0, max_rsp = *response_size;
> +     int offset;
> +
> +     err = get_name(0, buf, start, max, response, max_rsp, &name_len, end);
> +     if (err != 0)
> +             return err;
> +
> +     offset = name_len;
> +
> +     *type = ntohs(*(uint16_t *)(*end));
> +     *(uint16_t *)(response + offset) = *(uint16_t *)(*end);
> +     offset += 2;
> +     (*end) += 2;
> +
> +     *class = ntohs(*(uint16_t *)(*end));
> +     *(uint16_t *)(response + offset) = *(uint16_t *)(*end);
> +     offset += 2;
> +     (*end) += 2;
> +
> +     *ttl = ntohl(*(uint32_t *)(*end));
> +     *(uint32_t *)(response + offset) = *(uint32_t *)(*end);
> +     offset += 4;
> +     (*end) += 4;
> +
> +     _rdlen = ntohs(*(uint16_t *)(*end));
> +     *(uint16_t *)(response + offset) = *(uint16_t *)(*end);
> +     offset += 2;
> +     (*end) += 2;
> +
> +     memcpy(response + offset, *end, _rdlen);
> +     offset += _rdlen;
> +     (*end) += _rdlen;
> +
> +     *response_size = offset;
> +
> +     return 0;
> +}
> +
> +static int parse_response(unsigned char *buf, int buflen,
> +                     unsigned char *question, int qlen,
> +                     int *type, int *class, int *ttl,
> +                     unsigned char *response, unsigned int *resp_len,
> +                     int *answers)
> +{
> +     struct domain_hdr *hdr = (void *) buf;
> +     unsigned char *ptr;
> +     uint16_t qdcount = ntohs(hdr->qdcount);
> +     uint16_t ancount = ntohs(hdr->ancount);
> +     int err, i, qtype, qclass;
> +     unsigned char *next = NULL;
> +
> +     if (buflen < 12)
> +             return -EINVAL;
> +
> +     DBG("qr %d qdcount %d", hdr->qr, qdcount);
> +
> +     /* We currently only cache responses where question count is 1 */
> +     if (hdr->qr != 1 || qdcount != 1)
> +             return -EINVAL;
> +
> +     ptr = buf + sizeof(struct domain_hdr);
> +
> +     strncpy((char *)question, (char *)ptr, qlen);
> +     qlen = strlen((char *)question);
> +     ptr += qlen + 1;
> +     qtype = ntohs(*(uint16_t *)ptr);
> +
> +     /* We cache only A and AAAA records */
> +     if (qtype != 1 && qtype != 28)
> +             return -ENOMSG;
> +
> +     ptr += 2;
> +     qclass = ntohs(*(uint16_t *)ptr);
> +     ptr += 2; /* ptr points now to answers */
> +
> +     err = -ENOMSG;
> +     *resp_len = 0;
> +     *answers = 0;
> +
> +     for (i = 0; i < ancount; i++) {
> +             unsigned char rsp[128];
> +             unsigned int rsp_len = sizeof(rsp) - 1;
> +             int len = 0;
> +
> +             memset(rsp, 0, sizeof(rsp));
> +             err = parse_rr(buf, ptr, buf + buflen, rsp, &rsp_len, &len,
> +                     type, class, ttl, &next);
> +             if (err != 0)
> +                     return err;
> +
> +             if ((strncmp((char *)question, (char *)rsp, len) == 0) &&
> +                             *type == qtype && *class == qclass) {

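Should we also check that the lengths of question and rsp are within the
limits here? Also note that parse_rr() never writes *len, so len is always
0 at this point and the strncmp() always reports a match.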

> +                     memcpy(response + *resp_len, rsp, rsp_len);
> +                     *resp_len += rsp_len;
> +                     err = 0;
> +                     (*answers)++;
> +             }
> +
> +             ptr = next;
> +             next = NULL;
> +     }
> +
> +     return err;
> +}
> +
> +static int cache_update(struct server_data *data, unsigned char *msg,
> +                     unsigned int msg_len)
> +{
> +     int offset = protocol_offset(data->protocol);
> +     struct cache_entry *entry;
> +     unsigned char question[256];
> +     unsigned char response[256];
> +     int err, type = 0, class = 0, ttl = 0, answers;
> +     unsigned int rsplen;
> +
> +     if (cache_size > MAX_CACHE_SIZE)
> +             return 0;
> +
> +     /* Continue only if response code is 0 (=ok) */
> +     if (msg[3] & 0x0f)
> +             return 0;
> +
> +     if (offset < 0)
> +             return 0;
> +
> +     rsplen = sizeof(response) - 1;
> +     err = parse_response(msg + offset, msg_len - offset,
> +                             question, sizeof(question) - 1,
> +                             &type, &class, &ttl,
> +                             response, &rsplen, &answers);
> +     if (err < 0)
> +             return 0;
> +
> +     if (g_hash_table_lookup(data->cache, question) == NULL) {

Why not check for != NULL and just return early from the function?

> +             int qlen = strlen((char *)question);
> +             unsigned char *ptr;
> +
> +             entry = g_try_new0(struct cache_entry, 1);

When using g_try_new0(), you need to check the result for NULL and bail
out gracefully.

> +             entry->cache = data->cache;
> +             entry->key = (unsigned char *)g_strdup((char *)question);
> +             entry->type = type;
> +             entry->answers = answers;
> +             entry->data_len = 12 + qlen + 1 + 2 + 2 + rsplen;
> +             entry->data = ptr = g_malloc0(entry->data_len);

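Here we really want to check the result and bail out gracefully. Note that
g_malloc0() aborts the program on allocation failure instead of returning
NULL, so use g_try_malloc() if the failure is to be handled.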

> +
> +             memcpy(ptr, msg, 12);
> +             ptr += 12;
> +             memcpy(ptr, question, qlen);
> +             ptr += qlen + 1;
> +             *(uint16_t *)ptr = htons(type);
> +             ptr += 2;
> +             *(uint16_t *)ptr = htons(class);
> +             ptr += 2;
> +             memcpy(ptr, response, rsplen);

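If you copy data into the allocated buffer anyway, then zeroing it first
is a waste. Just make sure the terminating zero byte after the question
name (which "ptr += qlen + 1" currently relies on) is then written
explicitly.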

And there is no need to keep moving the pointer forward. Just use
memcpy(ptr + offset, ...) style operations.

Also, what about unaligned access? The uint16_t stores for type and class
follow the variable-length question name. Do we need to care?

> +             /*
> +              * Restrict the cached dns entry ttl to some sane value
> +              * in order to prevent data staying in the cache too long.
> +              */
> +             if (ttl > MAX_CACHE_TTL)
> +                     ttl = MAX_CACHE_TTL;
> +
> +             entry->timeout = g_timeout_add_seconds(ttl,
> +                                             cache_entry_timeout,
> +                                             entry);
> +             g_hash_table_insert(data->cache, g_strdup((char *)entry->key),
> +                             entry);

Allocating the key twice (once for entry->key and once more for the hash
table key) is not a good idea.

> +             cache_size++;
> +
> +             DBG("cache %d question \"%s\" type %d ttl %d len %d",
> +                     cache_size, question, type, ttl,
> +                     sizeof(*entry) + entry->data_len + qlen);
> +     }
> +
> +     return 0;
> +}
> +
>  static int ns_resolv(struct server_data *server, struct request_data *req,
>                               gpointer request, gpointer name)
>  {
>       GList *list;
>       int sk, err;
>       char *dot, *lookup = (char *) name;
> +     struct cache_entry *entry;
> +
> +     entry = cache_check(server, request);
> +     if (entry != NULL) {
> +             DBG("cache hit %s", lookup);
> +
> +             if (req->protocol == IPPROTO_TCP) {
> +                     send_cached_response(req->client_sk, entry->data,
> +                                     entry->data_len, NULL, 0, IPPROTO_TCP,
> +                                     req->srcid, entry->answers);
> +                     return 1;
> +             } else if (req->protocol == IPPROTO_UDP) {
> +                     int sk;
> +                     sk = g_io_channel_unix_get_fd(
> +                                     req->ifdata->udp_listener_channel);
> +
> +                     send_cached_response(sk, entry->data,
> +                             entry->data_len, &req->sa, req->sa_len,
> +                             IPPROTO_UDP, req->srcid, entry->answers);
> +                     return 1;
> +             }
> +     }
>  
>       sk = g_io_channel_unix_get_fd(server->channel);
>  
> @@ -396,7 +741,8 @@ static int ns_resolv(struct server_data *server, struct request_data *req,
>       return 0;
>  }
>  
> -static int forward_dns_reply(unsigned char *reply, int reply_len, int protocol)
> +static int forward_dns_reply(unsigned char *reply, int reply_len, int protocol,
> +                             struct server_data *data)
>  {
>       struct domain_hdr *hdr;
>       struct request_data *req;
> @@ -434,6 +780,8 @@ static int forward_dns_reply(unsigned char *reply, int reply_len, int protocol)
>  
>               memcpy(req->resp, reply, reply_len);
>               req->resplen = reply_len;
> +
> +             cache_update(data, reply, reply_len);
>       }
>  
>       if (hdr->rcode > 0 && req->numresp < req->numserv)
> @@ -460,6 +808,24 @@ static int forward_dns_reply(unsigned char *reply, int reply_len, int protocol)
>       return err;
>  }
>  
> +static void cache_element_destroy(gpointer value)
> +{
> +     struct cache_entry *entry = value;
> +
> +     if (entry == NULL)
> +             return;
> +
> +     if (entry->timeout > 0)
> +             g_source_remove(entry->timeout);
> +
> +     g_free(entry->key);
> +     g_free(entry->data);
> +     g_free(entry);
> +
> +     if (--cache_size < 0)
> +             cache_size = 0;
> +}
> +
>  static void destroy_server(struct server_data *server)
>  {
>       GList *list;
> @@ -488,6 +854,9 @@ static void destroy_server(struct server_data *server)
>               g_free(domain);
>       }
>       g_free(server->interface);
> +
> +     g_hash_table_destroy(server->cache);
> +
>       g_free(server);
>  }
>  
> @@ -496,10 +865,9 @@ static gboolean udp_server_event(GIOChannel *channel, GIOCondition condition,
>  {
>       unsigned char buf[4096];
>       int sk, err, len;
> +     struct server_data *data = user_data;
>  
>       if (condition & (G_IO_NVAL | G_IO_ERR | G_IO_HUP)) {
> -             struct server_data *data = user_data;
> -
>               connman_error("Error with UDP server %s", data->server);
>               data->watch = 0;
>               return FALSE;
> @@ -511,7 +879,7 @@ static gboolean udp_server_event(GIOChannel *channel, GIOCondition condition,
>       if (len < 12)
>               return TRUE;
>  
> -     err = forward_dns_reply(buf, len, IPPROTO_UDP);
> +     err = forward_dns_reply(buf, len, IPPROTO_UDP, data);
>       if (err < 0)
>               return TRUE;
>  
> @@ -612,7 +980,16 @@ hangup:
>  
>                       req->timeout = g_timeout_add_seconds(30,
>                                               request_timeout, req);
> -                     ns_resolv(server, req, req->request, req->name);
> +                     if (ns_resolv(server, req, req->request,
> +                                     req->name) > 0) {
> +                             /* We sent cached result so no need for timeout
> +                              * handler.
> +                              */
> +                             if (req->timeout > 0) {
> +                                     g_source_remove(req->timeout);
> +                                     req->timeout = 0;
> +                             }
> +                     }
>               }
>  
>       } else if (condition & G_IO_IN) {
> @@ -668,7 +1045,8 @@ hangup:
>                       reply->received += bytes_recv;
>               }
>  
> -             forward_dns_reply(reply->buf, reply->received, IPPROTO_TCP);
> +             forward_dns_reply(reply->buf, reply->received, IPPROTO_TCP,
> +                                     server);
>  
>               g_free(reply);
>               server->incoming_reply = NULL;
> @@ -818,6 +1196,11 @@ static struct server_data *create_server(const char *interface,
>               }
>       }
>  
> +     data->cache = g_hash_table_new_full(g_str_hash,
> +                                             g_str_equal,
> +                                             g_free,
> +                                             cache_element_destroy);

I prefer not to double-allocate the keys. Just have the hash table key
point to entry->key and free it via the cache_element_destroy callback
(and pass NULL instead of g_free as the key destroy function).

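However, you need to use g_hash_table_replace() instead of
g_hash_table_insert() to make this work properly. When the key already
exists, g_hash_table_insert() keeps the old key and destroys the old
value. The table would then be left pointing at the freed key of the old
entry, resulting in invalid memory access.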

> +
>       if (protocol == IPPROTO_UDP) {
>               /* Enable new servers by default */
>               data->enabled = TRUE;
> @@ -835,6 +1218,7 @@ static gboolean resolv(struct request_data *req,
>                               gpointer request, gpointer name)
>  {
>       GSList *list;
> +     int status;
>  
>       for (list = server_list; list; list = list->next) {
>               struct server_data *data = list->data;
> @@ -849,8 +1233,15 @@ static gboolean resolv(struct request_data *req,
>                               G_IO_IN | G_IO_NVAL | G_IO_ERR | G_IO_HUP,
>                                               udp_server_event, data);
>  
> -             if (ns_resolv(data, req, request, name) < 0)
> +             status = ns_resolv(data, req, request, name);
> +             if (status < 0)
>                       continue;
> +             else if (status > 0) {
> +                     if (req->timeout > 0) {
> +                             g_source_remove(req->timeout);
> +                             req->timeout = 0;
> +                     }
> +             }
>       }
>  
>       return TRUE;
> @@ -1223,7 +1614,12 @@ static gboolean tcp_listener_event(GIOChannel *channel, GIOCondition condition,
>               }
>  
>               req->timeout = g_timeout_add_seconds(30, request_timeout, req);
> -             ns_resolv(server, req, buf, query);
> +             if (ns_resolv(server, req, buf, query) > 0) {
> +                     if (req->timeout > 0) {
> +                             g_source_remove(req->timeout);
> +                             req->timeout = 0;
> +                     }
> +             }
>       }
>  
>       return TRUE;

Regards

Marcel


_______________________________________________
connman mailing list
[email protected]
http://lists.connman.net/listinfo/connman

Reply via email to