changeset 72237804fbcd in trytond:5.0
details: https://hg.tryton.org/trytond?cmd=changeset;node=72237804fbcd
description:
        Store listener threads per process

        This ensures that if the process is forked, the child process will start
        new listeners for the Cache and the Bus.

        issue9832
        review329601002
        (grafted from ecc6812907828d4971a308435c64d32090bd07d4)
diffstat:

 trytond/bus.py            |  52 +++++++++++++++++++++++++---------------------
 trytond/tests/test_bus.py |  10 +++++---
 2 files changed, 34 insertions(+), 28 deletions(-)

diffs (170 lines):

diff -r 1a547d070e26 -r 72237804fbcd trytond/bus.py
--- a/trytond/bus.py    Wed Nov 11 15:52:27 2020 +0100
+++ b/trytond/bus.py    Fri Nov 27 22:19:56 2020 +0100
@@ -4,6 +4,7 @@
 import collections
 import json
 import logging
+import os
 import select
 import threading
 import time
@@ -35,7 +36,7 @@
 
     def __init__(self, timeout):
         super().__init__()
-        self._lock = threading.Lock()
+        self._lock = collections.defaultdict(threading.Lock)
         self._timeout = timeout
         self._messages = []
 
@@ -65,7 +66,7 @@
             if first_message and not found:
                 message = first_message
 
-        with self._lock:
+        with self._lock[os.getpid()]:
             del self._messages[:to_delete_index]
 
         return message.channel, message.content
@@ -74,20 +75,21 @@
 class LongPollingBus:
 
     _channel = 'bus'
-    _queues_lock = threading.Lock()
+    _queues_lock = collections.defaultdict(threading.Lock)
     _queues = collections.defaultdict(
         lambda: {'timeout': None, 'events': collections.defaultdict(list)})
     _messages = {}
 
     @classmethod
     def subscribe(cls, database, channels, last_message=None):
-        with cls._queues_lock:
-            start_listener = database not in cls._queues
-            cls._queues[database]['timeout'] = time.time() + _db_timeout
+        pid = os.getpid()
+        with cls._queues_lock[pid]:
+            start_listener = (pid, database) not in cls._queues
+            cls._queues[pid, database]['timeout'] = time.time() + _db_timeout
             if start_listener:
                 listener = threading.Thread(
                     target=cls._listen, args=(database,), daemon=True)
-                cls._queues[database]['listener'] = listener
+                cls._queues[pid, database]['listener'] = listener
                 listener.start()
 
         messages = cls._messages.get(database)
@@ -98,11 +100,12 @@
 
         event = threading.Event()
         for channel in channels:
-            if channel in cls._queues[database]['events']:
-                event_channel = cls._queues[database]['events'][channel]
+            if channel in cls._queues[pid, database]['events']:
+                event_channel = cls._queues[pid, database]['events'][channel]
             else:
-                with cls._queues_lock:
-                    event_channel = cls._queues[database]['events'][channel]
+                with cls._queues_lock[pid]:
+                    event_channel = cls._queues[pid, database][
+                        'events'][channel]
             event_channel.append(event)
 
         triggered = event.wait(_long_polling_timeout)
@@ -112,9 +115,9 @@
             response = cls.create_response(
                 *cls._messages[database].get_next(channels, last_message))
 
-        with cls._queues_lock:
+        with cls._queues_lock[pid]:
             for channel in channels:
-                events = cls._queues[database]['events'][channel]
+                events = cls._queues[pid, database]['events'][channel]
                 for e in events[:]:
                     if e.is_set():
                         events.remove(e)
@@ -139,6 +142,7 @@
 
         logger.info("listening on channel '%s'", cls._channel)
         conn = db.get_connection()
+        pid = os.getpid()
         try:
             cursor = conn.cursor()
             cursor.execute('LISTEN "%s"' % cls._channel)
@@ -147,7 +151,7 @@
             cls._messages[database] = messages = _MessageQueue(_cache_timeout)
 
             now = time.time()
-            while cls._queues[database]['timeout'] > now:
+            while cls._queues[pid, database]['timeout'] > now:
                 readable, _, _ = select.select([conn], [], [], _select_timeout)
                 if not readable:
                     continue
@@ -162,10 +166,10 @@
                     message = payload['message']
                     messages.append(channel, message)
 
-                    with cls._queues_lock:
-                        events = \
-                            cls._queues[database]['events'][channel].copy()
-                        cls._queues[database]['events'][channel].clear()
+                    with cls._queues_lock[pid]:
+                        events = cls._queues[pid, database][
+                            'events'][channel].copy()
+                        cls._queues[pid, database]['events'][channel].clear()
                     for event in events:
                         event.set()
                 now = time.time()
@@ -173,20 +177,20 @@
             logger.error('bus listener on "%s" crashed', database,
                 exc_info=True)
 
-            with cls._queues_lock:
-                del cls._queues[database]
+            with cls._queues_lock[pid]:
+                del cls._queues[pid, database]
             raise
         finally:
             db.put_connection(conn)
 
-        with cls._queues_lock:
-            if cls._queues[database]['timeout'] <= now:
-                del cls._queues[database]
+        with cls._queues_lock[pid]:
+            if cls._queues[pid, database]['timeout'] <= now:
+                del cls._queues[pid, database]
             else:
                 # A query arrived between the end of the while and here
                 listener = threading.Thread(
                     target=cls._listen, args=(database,), daemon=True)
-                cls._queues[database]['listener'] = listener
+                cls._queues[pid, database]['listener'] = listener
                 listener.start()
 
     @classmethod
diff -r 1a547d070e26 -r 72237804fbcd trytond/tests/test_bus.py
--- a/trytond/tests/test_bus.py Wed Nov 11 15:52:27 2020 +0100
+++ b/trytond/tests/test_bus.py Fri Nov 27 22:19:56 2020 +0100
@@ -1,5 +1,6 @@
 # This file is part of Tryton.  The COPYRIGHT file at the top level of
 # this repository contains the full copyright notices and license terms.
+import os
 import time
 import unittest
 from unittest.mock import patch
@@ -95,10 +96,11 @@
             setattr, bus, '_select_timeout', reset_select_timeout)
 
     def tearDown(self):
-        if DB_NAME in Bus._queues:
-            with Bus._queues_lock:
-                Bus._queues[DB_NAME]['timeout'] = 0
-                listener = Bus._queues[DB_NAME]['listener']
+        pid = os.getpid()
+        if (pid, DB_NAME) in Bus._queues:
+            with Bus._queues_lock[pid]:
+                Bus._queues[pid, DB_NAME]['timeout'] = 0
+                listener = Bus._queues[pid, DB_NAME]['listener']
             listener.join()
         Bus._messages.clear()
 

Reply via email to