---
 src/dnsproxy.c |  412 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 files changed, 404 insertions(+), 8 deletions(-)

diff --git a/src/dnsproxy.c b/src/dnsproxy.c
index d03b734..a9bab01 100644
--- a/src/dnsproxy.c
+++ b/src/dnsproxy.c
@@ -90,6 +90,7 @@ struct server_data {
        gboolean enabled;
        gboolean connected;
        struct partial_reply *incoming_reply;
+       GHashTable *cache;
 };
 
 struct request_data {
@@ -123,6 +124,33 @@ struct listener_data {
        guint tcp_listener_watch;
 };
 
+/*
+ * One cached DNS reply. Entries live in the per-server hash table
+ * (struct server_data, field cache) keyed by the wire-format question name.
+ */
+struct cache_entry {
+       GHashTable *cache;      /* hash table this entry is stored in */
+       unsigned char *key;     /* owned copy of the wire-format question name */
+       int type;               /* DNS type the entry answers (A=1 / AAAA=28 are looked up) */
+       guint timeout;          /* GLib source id of the expiry timer, 0 if none */
+       int answers;            /* number of answer records stored in data */
+       unsigned int data_len;  /* length of data in bytes */
+       unsigned char *data; /* contains dns header + body */
+};
+
+/*
+ * Upper bound on the number of cached DNS responses. cache_size is a
+ * single counter shared by the caches of all servers, so this limit is
+ * global. An entry needs roughly 100 bytes depending on the length of the
+ * DNS name (caching www.connman.net takes 93 bytes).
+ */
+enum { MAX_CACHE_SIZE = 256 };
+static int cache_size;
+
+/*
+ * Cached entries normally expire after the TTL (time-to-live) of the DNS
+ * answer; this caps that lifetime. The value is in seconds (30 minutes).
+ */
+enum { MAX_CACHE_TTL = 60 * 30 };
+
 static GSList *server_list = NULL;
 static GSList *request_list = NULL;
 static GSList *request_pending_list = NULL;
@@ -187,6 +215,34 @@ static struct server_data *find_server(const char 
*interface,
        return NULL;
 }
 
+/*
+ * Send a cached reply to a client.
+ *
+ * @buf/@len hold the cached DNS message in @protocol framing, i.e. the
+ * DNS header starts protocol_offset(@protocol) bytes into @buf. The header
+ * is patched in place with the client's query id @id and the answer count
+ * @answers; authority and additional counts are cleared because cache
+ * entries only keep answer records. @to/@tolen address the client (NULL/0
+ * for a connected TCP socket).
+ */
+static void send_cached_response(int sk, unsigned char *buf, int len,
+                               const struct sockaddr *to, socklen_t tolen,
+                               int protocol, int id, int answers)
+{
+       struct domain_hdr *hdr;
+       int err, offset = protocol_offset(protocol);
+
+       if (offset < 0)
+               return;
+
+       /* A complete 12-byte DNS header must follow the framing prefix */
+       if (len < offset + 12)
+               return;
+
+       hdr = (void *) (buf + offset);
+
+       hdr->id = id;
+       hdr->qr = 1;
+       hdr->rcode = 0;
+       hdr->ancount = htons(answers);
+       hdr->nscount = 0;
+       hdr->arcount = 0;
+
+       DBG("id 0x%04x answers %d", hdr->id, answers);
+
+       err = sendto(sk, buf, len, 0, to, tolen);
+       if (err < 0)
+               connman_error("Cannot send cached DNS response: %s",
+                                                       strerror(errno));
+}
 
 static void send_response(int sk, unsigned char *buf, int len,
                                const struct sockaddr *to, socklen_t tolen,
@@ -325,12 +381,301 @@ static int append_query(unsigned char *buf, unsigned int 
size,
        return ptr - buf;
 }
 
+/*
+ * Return the entry of @server's cache that answers @request, or NULL.
+ *
+ * @request is the raw client query. Its framing is assumed to follow
+ * @server's protocol (protocol_offset(): a length prefix for TCP, none for
+ * UDP), the same convention cache_update() applies to replies.
+ * NOTE(review): confirm against the ns_resolv() callers. The question name
+ * is assumed to be NUL-terminated wire format, already validated before
+ * ns_resolv() is reached (TODO confirm).
+ *
+ * Only A (1) and AAAA (28) questions are served from the cache.
+ */
+static struct cache_entry *cache_check(struct server_data *server,
+                                       gpointer request)
+{
+       int offset = protocol_offset(server->protocol);
+       unsigned char *question, *qtype;
+       struct cache_entry *entry;
+       uint16_t type;
+
+       if (offset < 0 || server->cache == NULL)
+               return NULL;
+
+       /* The question name follows the framing prefix and the DNS header */
+       question = (unsigned char *) request + offset + 12;
+       qtype = question + strlen((char *) question) + 1;
+
+       /* QTYPE is not necessarily 16-bit aligned, read it byte-wise */
+       type = (qtype[0] << 8) | qtype[1];
+
+       /* ATM we only cache either A (1) or AAAA (28) requests */
+       if (type != 1 && type != 28)
+               return NULL;
+
+       entry = g_hash_table_lookup(server->cache, question);
+       if (entry == NULL || entry->type != type)
+               return NULL;
+
+       return entry;
+}
+
+/*
+ * GLib timeout callback: the lifetime of a cached entry is over.
+ * Removing the entry from its hash table releases it through the table's
+ * value destroy function (cache_element_destroy()).
+ */
+static gboolean cache_entry_timeout(gpointer user_data)
+{
+       struct cache_entry *entry = user_data;
+
+       DBG("cache %d key \"%s\"", cache_size - 1, entry->key);
+
+       /*
+        * This timer is being dispatched right now and is removed by GLib
+        * once we return FALSE; forget its id so the destroy function does
+        * not call g_source_remove() on it as well.
+        */
+       entry->timeout = 0;
+
+       g_hash_table_remove(entry->cache, entry->key);
+
+       return FALSE;
+}
+
+/*
+ * Walk the (possibly compressed) domain name that starts at @start inside
+ * the DNS message @pkt; valid bytes end just before @max.
+ *
+ * The name is not copied out: only answers to the single question are
+ * cached, so @name always receives the compression pointer 0xC0 0x0C
+ * (offset 12, the question name right after the header) and *name_len is
+ * set to 2. @max_name is the capacity of @name.
+ *
+ * *end must be NULL on entry. On success it points to the first byte
+ * after the name as stored at @start: just past the first compression
+ * pointer, or just past the terminating zero label.
+ *
+ * @counter is the number of compression pointers followed so far; more
+ * than 10 are rejected so that pointer loops cannot recurse forever.
+ *
+ * Returns 0 on success, -EINVAL on too many pointers and -ENOBUFS when the
+ * name runs past @max or @name is too small.
+ */
+static int get_name(int counter, unsigned char *pkt, unsigned char *start,
+               unsigned char *max, unsigned char *name, int max_name,
+               int *name_len, unsigned char **end)
+{
+       unsigned char *p = start;
+
+       if (counter > 10)
+               return -EINVAL;
+
+       if (max_name < 2)
+               return -ENOBUFS;
+
+       /* We compress the result and use pointers */
+       name[0] = 0xC0;
+       name[1] = 0x0C;
+       *name_len = 2;
+
+       while (p < max) {
+               if (*p == 0) {
+                       /* Zero label: end of an uncompressed name */
+                       if (*end == NULL)
+                               *end = p + 1;
+
+                       return 0;
+               }
+
+               if (*p & 0xc0) {
+                       uint16_t offset;
+
+                       /* A compression pointer takes two bytes */
+                       if (p + 1 >= max)
+                               return -ENOBUFS;
+
+                       offset = (*p & 0x3F) * 256 + *(p + 1);
+                       if (offset >= max - pkt)
+                               return -ENOBUFS;
+
+                       /* The name at @start ends with its first pointer */
+                       if (*end == NULL)
+                               *end = p + 2;
+
+                       return get_name(counter + 1, pkt, pkt + offset, max,
+                                       name, max_name, name_len, end);
+               }
+
+               /* Ordinary label: length byte followed by that many bytes */
+               p += *p + 1;
+       }
+
+       return -ENOBUFS;
+}
+
+/*
+ * Parse one resource record starting at @start in the DNS message @buf
+ * (valid bytes end just before @max) and copy it into @response as the
+ * owner name produced by get_name(), followed by TYPE, CLASS, TTL,
+ * RDLENGTH and RDATA in network byte order.
+ *
+ * On entry *response_size is the capacity of @response; on success it is
+ * the number of bytes written. *type, *class and *ttl receive the record's
+ * values in host byte order and *end (NULL on entry) points just past the
+ * record. @len is not written.
+ *
+ * Returns 0 on success or a negative errno when the record is malformed,
+ * runs past @max or does not fit into @response. All lengths come from
+ * the network and are checked before any copy.
+ */
+static int parse_rr(unsigned char *buf, unsigned char *start,
+               unsigned char *max, unsigned char *response,
+               unsigned int *response_size, int *len,
+               int *type, int *class, int *ttl,
+               unsigned char **end)
+{
+       int err, name_len = 0, max_rsp = *response_size;
+       int offset;
+       uint16_t val16, rdlen;
+       uint32_t val32;
+
+       err = get_name(0, buf, start, max, response, max_rsp, &name_len, end);
+       if (err != 0)
+               return err;
+
+       if (*end == NULL)
+               return -EINVAL;
+
+       /* TYPE, CLASS, TTL and RDLENGTH take 10 bytes in total */
+       if (max - *end < 10 || max_rsp - name_len < 10)
+               return -ENOBUFS;
+
+       offset = name_len;
+
+       /* Copy the fixed part unchanged, decode it without unaligned loads */
+       memcpy(response + offset, *end, 10);
+
+       memcpy(&val16, *end, 2);
+       *type = ntohs(val16);
+       memcpy(&val16, *end + 2, 2);
+       *class = ntohs(val16);
+       memcpy(&val32, *end + 4, 4);
+       *ttl = ntohl(val32);
+       memcpy(&val16, *end + 8, 2);
+       rdlen = ntohs(val16);
+
+       offset += 10;
+       (*end) += 10;
+
+       /* RDATA must lie inside the packet and fit into @response */
+       if (max - *end < rdlen || max_rsp - offset < rdlen)
+               return -ENOBUFS;
+
+       memcpy(response + offset, *end, rdlen);
+       offset += rdlen;
+       (*end) += rdlen;
+
+       *response_size = offset;
+
+       return 0;
+}
+
+/*
+ * Parse the DNS reply in @buf (@buflen bytes, without any TCP length
+ * prefix) and collect the answer records worth caching.
+ *
+ * Only replies with exactly one A (1) or AAAA (28) question are handled.
+ * On success:
+ *   @question    receives the wire-format question name, NUL terminated;
+ *                the buffer must provide at least @qlen bytes
+ *   @type/@class receive the question type and class
+ *   @ttl         receives the smallest TTL of the collected records
+ *                (TTLs with the top bit set count as 0, RFC 2181)
+ *   @response    receives the collected records back to back as written
+ *                by parse_rr(); on entry *resp_len is the capacity of
+ *                @response, on return the number of bytes used
+ *   @answers     receives the number of collected records
+ *
+ * Returns 0 when at least one answer with the question's type and class
+ * was collected, -ENOMSG when there is none or the question type is not
+ * cacheable, and another negative errno for malformed replies or when the
+ * data does not fit into the provided buffers.
+ */
+static int parse_response(unsigned char *buf, int buflen,
+                       unsigned char *question, int qlen,
+                       int *type, int *class, int *ttl,
+                       unsigned char *response, unsigned int *resp_len,
+                       int *answers)
+{
+       struct domain_hdr *hdr = (void *) buf;
+       unsigned char *max = buf + buflen;
+       unsigned char *ptr, *name_end;
+       unsigned int resp_max = *resp_len;
+       uint16_t qdcount, ancount;
+       int err, i, qtype, qclass, name_len;
+       unsigned char *next = NULL;
+
+       if (buflen < 12)
+               return -EINVAL;
+
+       qdcount = ntohs(hdr->qdcount);
+       ancount = ntohs(hdr->ancount);
+
+       DBG("qr %d qdcount %d", hdr->qr, qdcount);
+
+       /* We currently only cache responses where question count is 1 */
+       if (hdr->qr != 1 || qdcount != 1)
+               return -EINVAL;
+
+       ptr = buf + sizeof(struct domain_hdr);
+
+       /* Copy the question name including its zero label, bounded */
+       name_end = memchr(ptr, 0, max - ptr);
+       if (name_end == NULL)
+               return -EINVAL;
+
+       name_len = name_end - ptr;
+       if (name_len >= qlen)
+               return -ENOBUFS;
+
+       memcpy(question, ptr, name_len + 1);
+       ptr = name_end + 1;
+
+       /* QTYPE and QCLASS follow the name */
+       if (max - ptr < 4)
+               return -EINVAL;
+
+       qtype = (ptr[0] << 8) | ptr[1];
+
+       /* We cache only A and AAAA records */
+       if (qtype != 1 && qtype != 28)
+               return -ENOMSG;
+
+       qclass = (ptr[2] << 8) | ptr[3];
+       ptr += 4; /* ptr points now to answers */
+
+       err = -ENOMSG;
+       *resp_len = 0;
+       *answers = 0;
+
+       for (i = 0; i < ancount; i++) {
+               unsigned char rsp[128];
+               unsigned int rsp_len = sizeof(rsp) - 1;
+               int len = 0, rtype, rclass, rttl, rr_err;
+
+               memset(rsp, 0, sizeof(rsp));
+               rr_err = parse_rr(buf, ptr, max, rsp, &rsp_len, &len,
+                                       &rtype, &rclass, &rttl, &next);
+               if (rr_err != 0)
+                       return rr_err;
+
+               /*
+                * Owner names are rewritten to a pointer to the question,
+                * so the record's type and class decide whether it answers
+                * our question (CNAME records etc. are skipped).
+                */
+               if (rtype == qtype && rclass == qclass) {
+                       /* Never cache a truncated answer set */
+                       if (rsp_len > resp_max - *resp_len)
+                               return -ENOBUFS;
+
+                       memcpy(response + *resp_len, rsp, rsp_len);
+                       *resp_len += rsp_len;
+
+                       if (rttl < 0)
+                               rttl = 0;
+                       if (*answers == 0 || rttl < *ttl)
+                               *ttl = rttl;
+
+                       (*answers)++;
+                       err = 0;
+               }
+
+               ptr = next;
+               next = NULL;
+       }
+
+       if (err == 0) {
+               *type = qtype;
+               *class = qclass;
+       }
+
+       return err;
+}
+
+/*
+ * Store the reply @msg (@msg_len bytes as received from server @data,
+ * including the protocol_offset() framing prefix, i.e. the TCP length)
+ * in @data's cache when it successfully answers a single A/AAAA question.
+ *
+ * The cached data is a complete DNS message in the same framing as the
+ * server's protocol: [TCP length prefix] + header of the reply + question
+ * + the matching answer records. The header's id and counts are patched
+ * when the entry is sent (send_cached_response()).
+ *
+ * Returns 0 (also when the reply is not cacheable) or -ENOMEM.
+ */
+static int cache_update(struct server_data *data, unsigned char *msg,
+                       unsigned int msg_len)
+{
+       int offset = protocol_offset(data->protocol);
+       struct cache_entry *entry;
+       unsigned char question[256] = { 0 };
+       unsigned char response[256];
+       int err, type = 0, class = 0, ttl = 0, answers = 0;
+       unsigned int rsplen, qlen;
+       unsigned char *ptr;
+
+       /* At most MAX_CACHE_SIZE entries are cached */
+       if (cache_size >= MAX_CACHE_SIZE)
+               return 0;
+
+       if (offset < 0)
+               return 0;
+
+       /* A complete DNS header must follow the framing prefix */
+       if (msg_len < (unsigned int) offset + 12)
+               return 0;
+
+       /* Continue only if response code is 0 (=ok) */
+       if (msg[offset + 3] & 0x0f)
+               return 0;
+
+       rsplen = sizeof(response) - 1;
+       err = parse_response(msg + offset, msg_len - offset,
+                               question, sizeof(question) - 1,
+                               &type, &class, &ttl,
+                               response, &rsplen, &answers);
+       if (err < 0 || answers <= 0)
+               return 0;
+
+       /*
+        * A zero TTL means the answer must not be cached (RFC 1035); TTLs
+        * with the top bit set show up negative here and count as zero.
+        */
+       if (ttl <= 0)
+               return 0;
+
+       /* An existing entry is kept until it expires */
+       if (g_hash_table_lookup(data->cache, question) != NULL)
+               return 0;
+
+       entry = g_try_new0(struct cache_entry, 1);
+       if (entry == NULL)
+               return -ENOMEM;
+
+       qlen = strlen((char *) question);
+
+       entry->cache = data->cache;
+       entry->key = (unsigned char *) g_strdup((char *) question);
+       entry->type = type;
+       entry->answers = answers;
+       entry->data_len = offset + 12 + qlen + 1 + 2 + 2 + rsplen;
+       entry->data = ptr = g_malloc0(entry->data_len);
+
+       /* Framing prefix (TCP only) and the DNS header of the reply */
+       memcpy(ptr, msg, offset + 12);
+       if (data->protocol == IPPROTO_TCP) {
+               /* The TCP length prefix must describe the cached message */
+               unsigned int dns_len = entry->data_len - offset;
+
+               ptr[0] = (dns_len >> 8) & 0xff;
+               ptr[1] = dns_len & 0xff;
+       }
+       ptr += offset + 12;
+
+       /* Question: name (zero label left by g_malloc0), type, class */
+       memcpy(ptr, question, qlen);
+       ptr += qlen + 1;
+       *ptr++ = (type >> 8) & 0xff;
+       *ptr++ = type & 0xff;
+       *ptr++ = (class >> 8) & 0xff;
+       *ptr++ = class & 0xff;
+
+       /* Answers collected by parse_response() */
+       memcpy(ptr, response, rsplen);
+
+       /*
+        * Restrict the cached dns entry ttl to some sane value
+        * in order to prevent data staying in the cache too long.
+        */
+       if (ttl > MAX_CACHE_TTL)
+               ttl = MAX_CACHE_TTL;
+
+       entry->timeout = g_timeout_add_seconds(ttl, cache_entry_timeout,
+                                                               entry);
+       g_hash_table_insert(data->cache, g_strdup((char *) entry->key),
+                                                               entry);
+       cache_size++;
+
+       DBG("cache %d question \"%s\" type %d ttl %d len %d",
+               cache_size, question, type, ttl,
+               (int) (sizeof(*entry) + entry->data_len + qlen));
+
+       return 0;
+}
+
 static int ns_resolv(struct server_data *server, struct request_data *req,
                                gpointer request, gpointer name)
 {
        GList *list;
        int sk, err;
        char *dot, *lookup = (char *) name;
+       struct cache_entry *entry;
+
+       entry = cache_check(server, request);
+       if (entry != NULL) {
+               DBG("cache hit %s", lookup);
+
+               if (req->protocol == IPPROTO_TCP) {
+                       send_cached_response(req->client_sk, entry->data,
+                                       entry->data_len, NULL, 0, IPPROTO_TCP,
+                                       req->srcid, entry->answers);
+                       return 1;
+               } else if (req->protocol == IPPROTO_UDP) {
+                       int sk;
+                       sk = g_io_channel_unix_get_fd(
+                                       req->ifdata->udp_listener_channel);
+
+                       send_cached_response(sk, entry->data,
+                               entry->data_len, &req->sa, req->sa_len,
+                               IPPROTO_UDP, req->srcid, entry->answers);
+                       return 1;
+               }
+       }
 
        sk = g_io_channel_unix_get_fd(server->channel);
 
@@ -396,7 +741,8 @@ static int ns_resolv(struct server_data *server, struct 
request_data *req,
        return 0;
 }
 
-static int forward_dns_reply(unsigned char *reply, int reply_len, int protocol)
+static int forward_dns_reply(unsigned char *reply, int reply_len, int protocol,
+                               struct server_data *data)
 {
        struct domain_hdr *hdr;
        struct request_data *req;
@@ -434,6 +780,8 @@ static int forward_dns_reply(unsigned char *reply, int 
reply_len, int protocol)
 
                memcpy(req->resp, reply, reply_len);
                req->resplen = reply_len;
+
+               cache_update(data, reply, reply_len);
        }
 
        if (hdr->rcode > 0 && req->numresp < req->numserv)
@@ -460,6 +808,24 @@ static int forward_dns_reply(unsigned char *reply, int 
reply_len, int protocol)
        return err;
 }
 
+/*
+ * Value destroy function of the per-server cache hash table: cancels the
+ * pending expiry timer (if any), frees the entry and decrements the global
+ * cache_size counter without letting it drop below zero.
+ */
+static void cache_element_destroy(gpointer value)
+{
+       struct cache_entry *victim = value;
+
+       if (victim != NULL) {
+               if (victim->timeout > 0)
+                       g_source_remove(victim->timeout);
+
+               g_free(victim->key);
+               g_free(victim->data);
+               g_free(victim);
+
+               cache_size = (cache_size > 0) ? cache_size - 1 : 0;
+       }
+}
+
 static void destroy_server(struct server_data *server)
 {
        GList *list;
@@ -488,6 +854,9 @@ static void destroy_server(struct server_data *server)
                g_free(domain);
        }
        g_free(server->interface);
+
+       g_hash_table_destroy(server->cache);
+
        g_free(server);
 }
 
@@ -496,10 +865,9 @@ static gboolean udp_server_event(GIOChannel *channel, 
GIOCondition condition,
 {
        unsigned char buf[4096];
        int sk, err, len;
+       struct server_data *data = user_data;
 
        if (condition & (G_IO_NVAL | G_IO_ERR | G_IO_HUP)) {
-               struct server_data *data = user_data;
-
                connman_error("Error with UDP server %s", data->server);
                data->watch = 0;
                return FALSE;
@@ -511,7 +879,7 @@ static gboolean udp_server_event(GIOChannel *channel, 
GIOCondition condition,
        if (len < 12)
                return TRUE;
 
-       err = forward_dns_reply(buf, len, IPPROTO_UDP);
+       err = forward_dns_reply(buf, len, IPPROTO_UDP, data);
        if (err < 0)
                return TRUE;
 
@@ -612,7 +980,16 @@ hangup:
 
                        req->timeout = g_timeout_add_seconds(30,
                                                request_timeout, req);
-                       ns_resolv(server, req, req->request, req->name);
+                       if (ns_resolv(server, req, req->request,
+                                       req->name) > 0) {
+                               /* We sent cached result so no need for timeout
+                                * handler.
+                                */
+                               if (req->timeout > 0) {
+                                       g_source_remove(req->timeout);
+                                       req->timeout = 0;
+                               }
+                       }
                }
 
        } else if (condition & G_IO_IN) {
@@ -668,7 +1045,8 @@ hangup:
                        reply->received += bytes_recv;
                }
 
-               forward_dns_reply(reply->buf, reply->received, IPPROTO_TCP);
+               forward_dns_reply(reply->buf, reply->received, IPPROTO_TCP,
+                                       server);
 
                g_free(reply);
                server->incoming_reply = NULL;
@@ -818,6 +1196,11 @@ static struct server_data *create_server(const char 
*interface,
                }
        }
 
+       data->cache = g_hash_table_new_full(g_str_hash,
+                                               g_str_equal,
+                                               g_free,
+                                               cache_element_destroy);
+
        if (protocol == IPPROTO_UDP) {
                /* Enable new servers by default */
                data->enabled = TRUE;
@@ -835,6 +1218,7 @@ static gboolean resolv(struct request_data *req,
                                gpointer request, gpointer name)
 {
        GSList *list;
+       int status;
 
        for (list = server_list; list; list = list->next) {
                struct server_data *data = list->data;
@@ -849,8 +1233,15 @@ static gboolean resolv(struct request_data *req,
                                G_IO_IN | G_IO_NVAL | G_IO_ERR | G_IO_HUP,
                                                udp_server_event, data);
 
-               if (ns_resolv(data, req, request, name) < 0)
+               status = ns_resolv(data, req, request, name);
+               if (status < 0)
                        continue;
+               else if (status > 0) {
+                       if (req->timeout > 0) {
+                               g_source_remove(req->timeout);
+                               req->timeout = 0;
+                       }
+               }
        }
 
        return TRUE;
@@ -1223,7 +1614,12 @@ static gboolean tcp_listener_event(GIOChannel *channel, 
GIOCondition condition,
                }
 
                req->timeout = g_timeout_add_seconds(30, request_timeout, req);
-               ns_resolv(server, req, buf, query);
+               if (ns_resolv(server, req, buf, query) > 0) {
+                       if (req->timeout > 0) {
+                               g_source_remove(req->timeout);
+                               req->timeout = 0;
+                       }
+               }
        }
 
        return TRUE;
-- 
1.7.1

_______________________________________________
connman mailing list
[email protected]
http://lists.connman.net/listinfo/connman

Reply via email to