changeset 642bbe20ede1 in trytond:5.4
details: https://hg.tryton.org/trytond?cmd=changeset;node=642bbe20ede1
description:
        Store listener threads per process

        This ensures that if the process is forked, the child process will start
        new listeners for the Cache and the Bus.

        issue9832
        review329601002
        (grafted from ecc6812907828d4971a308435c64d32090bd07d4)
diffstat:

 trytond/bus.py            |  52 +++++++++++++++++++++++++---------------------
 trytond/cache.py          |  25 +++++++++++++---------
 trytond/tests/test_bus.py |  10 +++++---
 3 files changed, 49 insertions(+), 38 deletions(-)

diffs (245 lines):

diff -r 73aa96cd9f8a -r 642bbe20ede1 trytond/bus.py
--- a/trytond/bus.py    Wed Nov 11 15:52:04 2020 +0100
+++ b/trytond/bus.py    Fri Nov 27 22:19:56 2020 +0100
@@ -4,6 +4,7 @@
 import collections
 import json
 import logging
+import os
 import select
 import threading
 import time
@@ -43,7 +44,7 @@
 
     def __init__(self, timeout):
         super().__init__()
-        self._lock = threading.Lock()
+        self._lock = collections.defaultdict(threading.Lock)
         self._timeout = timeout
         self._messages = []
 
@@ -73,7 +74,7 @@
             if first_message and not found:
                 message = first_message
 
-        with self._lock:
+        with self._lock[os.getpid()]:
             del self._messages[:to_delete_index]
 
         return message.channel, message.content
@@ -82,20 +83,21 @@
 class LongPollingBus:
 
     _channel = 'bus'
-    _queues_lock = threading.Lock()
+    _queues_lock = collections.defaultdict(threading.Lock)
     _queues = collections.defaultdict(
         lambda: {'timeout': None, 'events': collections.defaultdict(list)})
     _messages = {}
 
     @classmethod
     def subscribe(cls, database, channels, last_message=None):
-        with cls._queues_lock:
-            start_listener = database not in cls._queues
-            cls._queues[database]['timeout'] = time.time() + _db_timeout
+        pid = os.getpid()
+        with cls._queues_lock[pid]:
+            start_listener = (pid, database) not in cls._queues
+            cls._queues[pid, database]['timeout'] = time.time() + _db_timeout
             if start_listener:
                 listener = threading.Thread(
                     target=cls._listen, args=(database,), daemon=True)
-                cls._queues[database]['listener'] = listener
+                cls._queues[pid, database]['listener'] = listener
                 listener.start()
 
         messages = cls._messages.get(database)
@@ -106,11 +108,12 @@
 
         event = threading.Event()
         for channel in channels:
-            if channel in cls._queues[database]['events']:
-                event_channel = cls._queues[database]['events'][channel]
+            if channel in cls._queues[pid, database]['events']:
+                event_channel = cls._queues[pid, database]['events'][channel]
             else:
-                with cls._queues_lock:
-                    event_channel = cls._queues[database]['events'][channel]
+                with cls._queues_lock[pid]:
+                    event_channel = cls._queues[pid, database][
+                        'events'][channel]
             event_channel.append(event)
 
         triggered = event.wait(_long_polling_timeout)
@@ -120,9 +123,9 @@
             response = cls.create_response(
                 *cls._messages[database].get_next(channels, last_message))
 
-        with cls._queues_lock:
+        with cls._queues_lock[pid]:
             for channel in channels:
-                events = cls._queues[database]['events'][channel]
+                events = cls._queues[pid, database]['events'][channel]
                 for e in events[:]:
                     if e.is_set():
                         events.remove(e)
@@ -147,6 +150,7 @@
 
         logger.info("listening on channel '%s'", cls._channel)
         conn = db.get_connection()
+        pid = os.getpid()
         try:
             cursor = conn.cursor()
             cursor.execute('LISTEN "%s"' % cls._channel)
@@ -155,7 +159,7 @@
             cls._messages[database] = messages = _MessageQueue(_cache_timeout)
 
             now = time.time()
-            while cls._queues[database]['timeout'] > now:
+            while cls._queues[pid, database]['timeout'] > now:
                 readable, _, _ = select.select([conn], [], [], _select_timeout)
                 if not readable:
                     continue
@@ -170,10 +174,10 @@
                     message = payload['message']
                     messages.append(channel, message)
 
-                    with cls._queues_lock:
-                        events = \
-                            cls._queues[database]['events'][channel].copy()
-                        cls._queues[database]['events'][channel].clear()
+                    with cls._queues_lock[pid]:
+                        events = cls._queues[pid, database][
+                            'events'][channel].copy()
+                        cls._queues[pid, database]['events'][channel].clear()
                     for event in events:
                         event.set()
                 now = time.time()
@@ -181,20 +185,20 @@
             logger.error('bus listener on "%s" crashed', database,
                 exc_info=True)
 
-            with cls._queues_lock:
-                del cls._queues[database]
+            with cls._queues_lock[pid]:
+                del cls._queues[pid, database]
             raise
         finally:
             db.put_connection(conn)
 
-        with cls._queues_lock:
-            if cls._queues[database]['timeout'] <= now:
-                del cls._queues[database]
+        with cls._queues_lock[pid]:
+            if cls._queues[pid, database]['timeout'] <= now:
+                del cls._queues[pid, database]
             else:
                 # A query arrived between the end of the while and here
                 listener = threading.Thread(
                     target=cls._listen, args=(database,), daemon=True)
-                cls._queues[database]['listener'] = listener
+                cls._queues[pid, database]['listener'] = listener
                 listener.start()
 
     @classmethod
diff -r 73aa96cd9f8a -r 642bbe20ede1 trytond/cache.py
--- a/trytond/cache.py  Wed Nov 11 15:52:04 2020 +0100
+++ b/trytond/cache.py  Fri Nov 27 22:19:56 2020 +0100
@@ -3,6 +3,7 @@
 import datetime as dt
 import json
 import logging
+import os
 import select
 import threading
 from collections import OrderedDict, defaultdict
@@ -99,7 +100,7 @@
     _clean_last = datetime.now()
     _default_lower = Transaction.monotonic_time()
     _listener = {}
-    _listener_lock = threading.Lock()
+    _listener_lock = defaultdict(threading.Lock)
     _table = 'ir_cache'
     _channel = _table
 
@@ -168,9 +169,10 @@
         database = transaction.database
         dbname = database.name
         if not _clear_timeout and database.has_channel():
-            with cls._listener_lock:
-                if dbname not in cls._listener:
-                    cls._listener[dbname] = listener = threading.Thread(
+            pid = os.getpid()
+            with cls._listener_lock[pid]:
+                if (pid, dbname) not in cls._listener:
+                    cls._listener[pid, dbname] = listener = threading.Thread(
                         target=cls._listen, args=(dbname,), daemon=True)
                     listener.start()
             return
@@ -259,8 +261,9 @@
 
     @classmethod
     def drop(cls, dbname):
-        with cls._listener_lock:
-            listener = cls._listener.pop(dbname, None)
+        pid = os.getpid()
+        with cls._listener_lock[pid]:
+            listener = cls._listener.pop((pid, dbname), None)
         if listener:
             Database = backend.get('Database')
             database = Database(dbname)
@@ -286,12 +289,14 @@
 
         logger.info("listening on channel '%s' of '%s'", cls._channel, dbname)
         conn = database.get_connection()
+        pid = os.getpid()
+        current_thread = threading.current_thread()
         try:
             cursor = conn.cursor()
             cursor.execute('LISTEN "%s"' % cls._channel)
             conn.commit()
 
-            while cls._listener.get(dbname) == threading.current_thread():
+            while cls._listener.get((pid, dbname)) == current_thread:
                 readable, _, _ = select.select([conn], [], [])
                 if not readable:
                     continue
@@ -310,9 +315,9 @@
             raise
         finally:
             database.put_connection(conn)
-            with cls._listener_lock:
-                if cls._listener.get(dbname) == threading.current_thread():
-                    del cls._listener[dbname]
+            with cls._listener_lock[pid]:
+                if cls._listener.get((pid, dbname)) == current_thread:
+                    del cls._listener[pid, dbname]
 
 
 if config.get('cache', 'class'):
diff -r 73aa96cd9f8a -r 642bbe20ede1 trytond/tests/test_bus.py
--- a/trytond/tests/test_bus.py Wed Nov 11 15:52:04 2020 +0100
+++ b/trytond/tests/test_bus.py Fri Nov 27 22:19:56 2020 +0100
@@ -1,5 +1,6 @@
 # This file is part of Tryton.  The COPYRIGHT file at the top level of
 # this repository contains the full copyright notices and license terms.
+import os
 import time
 import unittest
 from unittest.mock import patch
@@ -95,10 +96,11 @@
             setattr, bus, '_select_timeout', reset_select_timeout)
 
     def tearDown(self):
-        if DB_NAME in Bus._queues:
-            with Bus._queues_lock:
-                Bus._queues[DB_NAME]['timeout'] = 0
-                listener = Bus._queues[DB_NAME]['listener']
+        pid = os.getpid()
+        if (pid, DB_NAME) in Bus._queues:
+            with Bus._queues_lock[pid]:
+                Bus._queues[pid, DB_NAME]['timeout'] = 0
+                listener = Bus._queues[pid, DB_NAME]['listener']
             listener.join()
         Bus._messages.clear()
 

Reply via email to