Author: cito
Date: Sun Jan  6 12:15:51 2013
New Revision: 503

Log:
Some refactoring, avoid code duplication.
Use flag instead of pseudo event to check listener status.

Modified:
   trunk/module/TEST_PyGreSQL_classic.py
   trunk/module/pg.py

Modified: trunk/module/TEST_PyGreSQL_classic.py
==============================================================================
--- trunk/module/TEST_PyGreSQL_classic.py       Sun Jan  6 11:47:08 2013        (r502)
+++ trunk/module/TEST_PyGreSQL_classic.py       Sun Jan  6 12:15:51 2013        (r503)
@@ -3,6 +3,7 @@
 from __future__ import with_statement
 
 import sys
+from functools import partial
 from time import sleep
 from threading import Thread
 import unittest
@@ -227,138 +228,85 @@
         self.assertEqual(q("'", 'text'), "''''")
         self.assertEqual(q("\\", 'text'), "'\\\\'")
 
-    # note that notify can be created as part of the DB class or
-    # independently.
-
     def notify_callback(self, arg_dict):
         if arg_dict:
             arg_dict['called'] = True
         else:
             self.notify_timeout = True
 
-    def test_notify_DB(self):
-        db = opendb()
-        db2 = opendb()
-        arg_dict = {}
-        self.notify_timeout = False
-        # Listen for 'event_1'
-        pgn = db.pgnotify('event_1', self.notify_callback, arg_dict)
-        thread = Thread(None, pgn)
-        thread.start()
-        # Wait until the thread has started.
-        for n in xrange(500):
-            if 'event' in arg_dict:
-                break
-            sleep(0.01)
-        self.assertTrue('event' in arg_dict)
-        self.assertTrue(thread.isAlive())
-        # Generate notification from the other connection.
-        db2.query("notify event_1, 'payload_1'")
-        # Wait until the notification has been caught.
-        for n in xrange(500):
-            if arg_dict['event'] == 'event_1':
-                break
-            sleep(0.01)
-        self.assertEqual(arg_dict['event'], 'event_1')
-        self.assertEqual(arg_dict['extra'], 'payload_1')
-        self.assertTrue(isinstance(arg_dict['pid'], int))
-        # Check that callback has been invoked.
-        self.assertTrue(arg_dict.get('called'))
-        self.assertFalse(self.notify_timeout)
-        arg_dict['called'] = False
-        self.assertTrue(thread.isAlive())
-        # Generate stop notification
-        db2.query("notify stop_event_1, 'payload_1'")
-        # Wait until the notification has been caught.
-        for n in xrange(500):
-            if arg_dict['event'] == 'stop_event_1':
-                break
-            sleep(0.01)
-        self.assertEqual(arg_dict['event'], 'stop_event_1')
-        self.assertEqual(arg_dict['extra'], 'payload_1')
-        self.assertTrue(isinstance(arg_dict['pid'], int))
-        # Check that callback has been invoked.
-        self.assertTrue(arg_dict.get('called'))
-        self.assertFalse(self.notify_timeout)
-        thread.join(5)
-        self.assertFalse(thread.isAlive())
-
-    def test_notify_timeout_DB(self):
-        db = opendb()
-        arg_dict = {}
-        self.notify_timeout = False
-        # Listen for 'event_1'.
-        pgn = db.pgnotify('event_1', self.notify_callback, arg_dict, 0.01)
-        thread = Thread(None, pgn)
-        thread.start()
-        # Sleep long enough to time out.
-        sleep(0.02)
-        # Verify that we've indeed timed out.
-        self.assertFalse(arg_dict.get('called'))
-        self.assertTrue(self.notify_timeout)
-        self.assertFalse(thread.isAlive())
-
     def test_notify(self):
-        db = opendb()
-        db2 = opendb()
-        arg_dict = {}
-        self.notify_timeout = False
-        # Listen for 'event_1'
-        pgn = pgnotify(db, 'event_1', self.notify_callback, arg_dict)
-        thread = Thread(None, pgn)
-        thread.start()
-        # Wait until the thread has started.
-        for n in xrange(500):
-            if 'event' in arg_dict:
-                break
-            sleep(0.01)
-        self.assertTrue('event' in arg_dict)
-        self.assertTrue(thread.isAlive())
-        # Generate notification from the other connection.
-        db2.query("notify event_1, 'payload_1'")
-        # Wait until the notification has been caught.
-        for n in xrange(500):
-            if arg_dict['event'] == 'event_1':
-                break
-            sleep(0.01)
-        self.assertEqual(arg_dict['event'], 'event_1')
-        self.assertEqual(arg_dict['extra'], 'payload_1')
-        self.assertTrue(isinstance(arg_dict['pid'], int))
-        # Check that callback has been invoked.
-        self.assertTrue(arg_dict.get('called'))
-        self.assertFalse(self.notify_timeout)
-        arg_dict['called'] = False
-        self.assertTrue(thread.isAlive())
-        # Generate stop notification
-        db2.query("notify stop_event_1, 'payload_1'")
-        # Wait until the notification has been caught.
-        for n in xrange(500):
-            if arg_dict['event'] == 'stop_event_1':
-                break
-            sleep(0.01)
-        self.assertEqual(arg_dict['event'], 'stop_event_1')
-        self.assertEqual(arg_dict['extra'], 'payload_1')
-        self.assertTrue(isinstance(arg_dict['pid'], int))
-        # Check that callback has been invoked.
-        self.assertTrue(arg_dict.get('called'))
-        self.assertFalse(self.notify_timeout)
-        thread.join(5)
-        self.assertFalse(thread.isAlive())
+        for test_method in False, True:
+            db = opendb()
+            # Get function under test, can be standalone or DB method.
+            fut = db.pgnotify if test_method else partial(pgnotify, db)
+            arg_dict = dict(event=None, called=False)
+            self.notify_timeout = False
+            # Listen for 'event_1'
+            target = fut('event_1', self.notify_callback, arg_dict)
+            thread = Thread(None, target)
+            thread.start()
+            # Wait until the thread has started.
+            for n in xrange(500):
+                if target.listening:
+                    break
+                sleep(0.01)
+            self.assertTrue(target.listening)
+            self.assertTrue(thread.isAlive())
+            # Open another connection for sending notifications.
+            db2 = opendb()
+            # Generate notification from the other connection.
+            db2.query("notify event_1, 'payload_1'")
+            # Wait until the notification has been caught.
+            for n in xrange(500):
+                if arg_dict['event'] == 'event_1':
+                    break
+                sleep(0.01)
+            self.assertEqual(arg_dict['event'], 'event_1')
+            self.assertEqual(arg_dict['extra'], 'payload_1')
+            self.assertTrue(isinstance(arg_dict['pid'], int))
+            # Check that callback has been invoked.
+            self.assertTrue(arg_dict.get('called'))
+            self.assertFalse(self.notify_timeout)
+            arg_dict['called'] = False
+            self.assertTrue(thread.isAlive())
+            # Generate stop notification.
+            db2.query("notify stop_event_1, 'payload_2'")
+            db2.close()
+            # Wait until the notification has been caught.
+            for n in xrange(500):
+                if arg_dict['event'] == 'stop_event_1':
+                    break
+                sleep(0.01)
+            self.assertEqual(arg_dict['event'], 'stop_event_1')
+            self.assertEqual(arg_dict['extra'], 'payload_2')
+            self.assertTrue(isinstance(arg_dict['pid'], int))
+            # Check that callback has been invoked.
+            self.assertTrue(arg_dict.get('called'))
+            self.assertFalse(self.notify_timeout)
+            thread.join(5)
+            self.assertFalse(thread.isAlive())
+            self.assertFalse(target.listening)
+            target.close()
 
     def test_notify_timeout(self):
-        db = opendb()
-        arg_dict = {}
-        self.notify_timeout = False
-        # Listen for 'event_1'.
-        pgn = pgnotify(db, 'event_1', self.notify_callback, arg_dict, 0.01)
-        thread = Thread(None, pgn)
-        thread.start()
-        # Sleep long enough to time out.
-        sleep(0.02)
-        # Verify that we've indeed timed out.
-        self.assertFalse(arg_dict.get('called'))
-        self.assertTrue(self.notify_timeout)
-        self.assertFalse(thread.isAlive())
+        for test_method in False, True:
+            db = opendb()
+            # Get function under test, can be standalone or DB method.
+            fut = db.pgnotify if test_method else partial(pgnotify, db)
+            arg_dict = dict(event=None, called=False)
+            self.notify_timeout = False
+            # Listen for 'event_1' with timeout of 10ms
+            target = fut('event_1', self.notify_callback, arg_dict, 0.01)
+            thread = Thread(None, target)
+            thread.start()
+            # Sleep 20ms, long enough to time out.
+            sleep(0.02)
+            # Verify that we've indeed timed out.
+            self.assertFalse(arg_dict.get('called'))
+            self.assertTrue(self.notify_timeout)
+            self.assertFalse(thread.isAlive())
+            self.assertFalse(target.listening)
+            target.close()
 
 
 if __name__ == '__main__':

Modified: trunk/module/pg.py
==============================================================================
--- trunk/module/pg.py  Sun Jan  6 11:47:08 2013        (r502)
+++ trunk/module/pg.py  Sun Jan  6 12:15:51 2013        (r503)
@@ -132,10 +132,10 @@
 class pgnotify(object):
     """A PostgreSQL client-side asynchronous notification handler."""
 
-    def __init__(self, pgconn, event, callback, arg_dict=None, timeout=None):
+    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
         """Initialize the notification handler.
 
-        pgconn   - PostgreSQL connection object.
+        db   - PostgreSQL connection object.
         event    - Event to LISTEN for.
         callback - Event callback.
         arg_dict - A dictionary passed as the argument to the callback.
@@ -143,10 +143,10 @@
                     fractions of seconds. If it is absent or None, the
                     callers will never time out."""
 
-        self.pgconn = pgconn
+        self.db = db
         self.event = event
-        self.start = 'start_%s' % event
         self.stop = 'stop_%s' % event
+        self.listening = False
         self.callback = callback
         if arg_dict is None:
             arg_dict = {}
@@ -154,11 +154,25 @@
         self.timeout = timeout
 
     def __del__(self):
-        try:
-            self.pgconn.query('unlisten "%s"' % self.event)
-            self.pgconn.query('unlisten "%s"' % self.stop)
-        except DatabaseError:
-            pass
+        self.close()
+
+    def close(self):
+        if self.db:
+            self.unlisten()
+            self.db.close()
+            self.db = None
+
+    def listen(self):
+        if not self.listening:
+            self.db.query('listen "%s"' % self.event)
+            self.db.query('listen "%s"' % self.stop)
+            self.listening = True
+
+    def unlisten(self):
+        if self.listening:
+            self.db.query('unlisten "%s"' % self.event)
+            self.db.query('unlisten "%s"' % self.stop)
+            self.listening = False
 
     def __call__(self):
         """Invoke the handler.
@@ -172,20 +186,17 @@
         invoked with <arg_dict>. If the NOTIFY message is stop_<event>, the
         handler UNLISTENs both <event> and stop_<event> and exits."""
 
-        self.pgconn.query('listen "%s"' % self.event)
-        self.pgconn.query('listen "%s"' % self.stop)
-        self.arg_dict['event'] = self.start
-        _ilist = [self.pgconn.fileno()]
+        self.listen()
+        _ilist = [self.db.fileno()]
 
         while True:
             ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
             if ilist == []:  # we timed out
-                self.pgconn.query('unlisten "%s"' % self.event)
-                self.pgconn.query('unlisten "%s"' % self.stop)
+                self.unlisten()
                 self.callback(None)
                 break
             else:
-                notice = self.pgconn.getnotify()
+                notice = self.db.getnotify()
                 if notice is None:
                     continue
                 event, pid, extra = notice
@@ -195,12 +206,10 @@
                     self.arg_dict['extra'] = extra
                     self.callback(self.arg_dict)
                     if event == self.stop:
-                        self.pgconn.query('unlisten "%s"' % self.event)
-                        self.pgconn.query('unlisten "%s"' % self.stop)
+                        self.unlisten()
                         break
                 else:
-                    self.pgconn.query('unlisten "%s"' % self.event)
-                    self.pgconn.query('unlisten "%s"' % self.stop)
+                    self.unlisten()
                     raise _db_error(
                         'listening for "%s" and "%s", but notified of "%s"'
                         % (self.event, self.stop, event))
_______________________________________________
PyGreSQL mailing list
[email protected]
https://mail.vex.net/mailman/listinfo.cgi/pygresql

Reply via email to