Hello,

Attached is a patch that makes dnspython use poll() (via
select.poll()) when the platform supports it, and fall back to
regular select() otherwise. The unit tests exercise both backends
when poll() is available, and only select() otherwise.

Motivation: using select() means you hit a limit on the number
of file descriptors you can wait on. Worse, because of how the
select() interface works (fixed-size fd_set bitmaps), it also
limits the maximum numeric value of a file descriptor, even if you
are only interested in a single descriptor. See the acrobatics
Python performs in selectmodule.c for details.

So, to avoid failing once the process has more file descriptors
open than that limit allows, use poll() instead.

Thoughts?

-- 
/ Peter Schuller aka scode
diff --git a/dns/query.py b/dns/query.py
index c023b14..2a3013b 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -45,7 +45,75 @@ def _compute_expiration(timeout):
     else:
         return time.time() + timeout
 
-def _wait_for(ir, iw, ix, expiration):
+def _poll_for(fd, readable, writable, error, timeout):
+    """Wait for I{fd} to become ready using select.poll().
+
+    @param fd: File descriptor (int) or object with a fileno() method.
+    @param readable: Whether to wait for readability (bool).
+    @param writable: Whether to wait for writability (bool).
+    @param error: Whether to wait for an error condition (bool).
+    @param timeout: Relative timeout in seconds (float), or None to
+    block until an event is reported.
+    @return: True if an event was reported, False on timeout.
+    """
+    event_mask = 0
+    if readable:
+        event_mask |= select.POLLIN
+    if writable:
+        event_mask |= select.POLLOUT
+    if error:
+        event_mask |= select.POLLERR
+
+    pollable = select.poll()
+    pollable.register(fd, event_mask)
+
+    # poll() takes milliseconds.  Test for None explicitly so that a
+    # timeout of 0 means "do not block" rather than "block forever".
+    if timeout is None:
+        event_list = pollable.poll()
+    else:
+        event_list = pollable.poll(int(timeout * 1000))
+
+    return bool(event_list)
+
+def _select_for(fd, readable, writable, error, timeout):
+    """Wait for I{fd} to become ready using select.select().
+
+    @param fd: File descriptor (int) or object with a fileno() method.
+    @param readable: Whether to wait for readability (bool).
+    @param writable: Whether to wait for writability (bool).
+    @param error: Whether to wait for an error condition (bool).
+    @param timeout: Relative timeout in seconds (float), or None to
+    block until the descriptor is ready.
+    @return: True if the descriptor became ready, False on timeout.
+    """
+    rset, wset, xset = [], [], []
+
+    if readable:
+        rset = [fd]
+    if writable:
+        wset = [fd]
+    if error:
+        xset = [fd]
+
+    if timeout is None:
+        (rready, wready, xready) = select.select(rset, wset, xset)
+    else:
+        (rready, wready, xready) = select.select(rset, wset, xset, timeout)
+
+    return bool(rready or wready or xready)
+
+def _wait_for(fd, readable, writable, error, expiration):
+    """Wait until I{fd} is ready, using the current polling backend.
+
+    I{fd}, I{readable}, I{writable} and I{error} are passed to the
+    backend unchanged.  A wait interrupted by a signal (EINTR) is
+    retried with the time remaining until I{expiration}.
+
+    @param expiration: Absolute deadline (as from _compute_expiration()),
+    or None to wait forever.
+    @raises dns.exception.Timeout: The deadline passed first.
+    """
     done = False
     while not done:
         if expiration is None:
@@ -55,22 +123,38 @@ def _wait_for(ir, iw, ix, expiration):
             if timeout <= 0.0:
                 raise dns.exception.Timeout
         try:
-            if timeout is None:
-                (r, w, x) = select.select(ir, iw, ix)
-            else:
-                (r, w, x) = select.select(ir, iw, ix, timeout)
+            if not _polling_backend(fd, readable, writable, error, timeout):
+                raise dns.exception.Timeout
         except select.error, e:
             if e.args[0] != errno.EINTR:
                 raise e
         done = True
-        if len(r) == 0 and len(w) == 0 and len(x) == 0:
-            raise dns.exception.Timeout
+
+def _set_polling_backend(fn):
+    """
+    Internal API. Do not use.
+
+    Replace the backend that _wait_for() calls to wait on a descriptor.
+    I{fn} must take the same arguments as _poll_for() and _select_for();
+    this exists so the tests can exercise each backend.
+    """
+    global _polling_backend
+
+    _polling_backend = fn
+
+if hasattr(select, 'poll'):
+    # Prefer poll() on platforms that support it because it has no
+    # limits on the maximum value of a file descriptor (plus it will
+    # be more efficient for high values).
+    _polling_backend = _poll_for
+else:
+    _polling_backend = _select_for
 
 def _wait_for_readable(s, expiration):
-    _wait_for([s], [], [s], expiration)
+    _wait_for(s, True, False, True, expiration)
 
 def _wait_for_writable(s, expiration):
-    _wait_for([], [s], [s], expiration)
+    _wait_for(s, False, True, True, expiration)
 
 def _addresses_equal(af, a1, a2):
     # Convert the first value of the tuple, which is a textual format
diff --git a/tests/resolver.py b/tests/resolver.py
index 4cacbdc..843f116 100644
--- a/tests/resolver.py
+++ b/tests/resolver.py
@@ -14,6 +14,7 @@
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
 import cStringIO
+import select
 import sys
 import time
 import unittest
@@ -46,7 +47,7 @@ example. 1 IN A 10.0.0.1
 ;ADDITIONAL
 """
 
-class ResolverTestCase(unittest.TestCase):
+class BaseResolverTests(object):
 
     if sys.platform != 'win32':
         def testRead(self):
@@ -101,5 +102,36 @@ class ResolverTestCase(unittest.TestCase):
             zname = dns.resolver.zone_for_name(name)
         self.failUnlessRaises(dns.resolver.NotAbsolute, bad)
 
+class PollingMonkeyPatchMixin(object):
+    """Run the inherited tests with a specific dns.query polling backend.
+
+    Subclasses must define polling_backend(), returning the backend
+    function to install for the duration of each test.
+    """
+
+    def setUp(self):
+        # Initialize the test case first: if that fails, tearDown() is not
+        # run, so the backend must not have been swapped yet.
+        super(PollingMonkeyPatchMixin, self).setUp()
+        self.__native_polling_backend = dns.query._polling_backend
+        dns.query._set_polling_backend(self.polling_backend())
+
+    def tearDown(self):
+        # Always restore the original backend, even if the base tearDown
+        # fails, so later tests are not affected.
+        try:
+            super(PollingMonkeyPatchMixin, self).tearDown()
+        finally:
+            dns.query._set_polling_backend(self.__native_polling_backend)
+
+class SelectResolverTestCase(PollingMonkeyPatchMixin, BaseResolverTests, unittest.TestCase):
+    def polling_backend(self):
+        return dns.query._select_for
+
+if hasattr(select, 'poll'):
+    class PollResolverTestCase(PollingMonkeyPatchMixin, BaseResolverTests, unittest.TestCase):
+        def polling_backend(self):
+            return dns.query._poll_for
+
 if __name__ == '__main__':
     unittest.main()
_______________________________________________
dnspython-dev mailing list
dnspython-dev@howl.play-bow.org
http://howl.play-bow.org/mailman/listinfo.cgi/dnspython-dev

Reply via email to