changeset ecc681290782 in trytond:default
details: https://hg.tryton.org/trytond?cmd=changeset;node=ecc681290782
description:
        Store listener threads per process

        This ensures that if the process is forked, the child process will start
        new listeners for the Cache and the Bus.

        issue9832
        review329601002
diffstat:

 trytond/bus.py            |  52 +++++++++++++++++++++++++---------------------
 trytond/cache.py          |  25 +++++++++++++---------
 trytond/tests/test_bus.py |  10 +++++---
 3 files changed, 49 insertions(+), 38 deletions(-)

diffs (245 lines):

diff -r 64ef3c03836c -r ecc681290782 trytond/bus.py
--- a/trytond/bus.py    Fri Nov 27 20:49:09 2020 +0100
+++ b/trytond/bus.py    Fri Nov 27 22:19:56 2020 +0100
@@ -4,6 +4,7 @@
 import collections
 import json
 import logging
+import os
 import select
 import threading
 import time
@@ -44,7 +45,7 @@
 
     def __init__(self, timeout):
         super().__init__()
-        self._lock = threading.Lock()
+        self._lock = collections.defaultdict(threading.Lock)
         self._timeout = timeout
         self._messages = []
 
@@ -74,7 +75,7 @@
             if first_message and not found:
                 message = first_message
 
-        with self._lock:
+        with self._lock[os.getpid()]:
             del self._messages[:to_delete_index]
 
         return message.channel, message.content
@@ -83,20 +84,21 @@
 class LongPollingBus:
 
     _channel = 'bus'
-    _queues_lock = threading.Lock()
+    _queues_lock = collections.defaultdict(threading.Lock)
     _queues = collections.defaultdict(
         lambda: {'timeout': None, 'events': collections.defaultdict(list)})
     _messages = {}
 
     @classmethod
     def subscribe(cls, database, channels, last_message=None):
-        with cls._queues_lock:
-            start_listener = database not in cls._queues
-            cls._queues[database]['timeout'] = time.time() + _db_timeout
+        pid = os.getpid()
+        with cls._queues_lock[pid]:
+            start_listener = (pid, database) not in cls._queues
+            cls._queues[pid, database]['timeout'] = time.time() + _db_timeout
             if start_listener:
                 listener = threading.Thread(
                     target=cls._listen, args=(database,), daemon=True)
-                cls._queues[database]['listener'] = listener
+                cls._queues[pid, database]['listener'] = listener
                 listener.start()
 
         messages = cls._messages.get(database)
@@ -107,11 +109,12 @@
 
         event = threading.Event()
         for channel in channels:
-            if channel in cls._queues[database]['events']:
-                event_channel = cls._queues[database]['events'][channel]
+            if channel in cls._queues[pid, database]['events']:
+                event_channel = cls._queues[pid, database]['events'][channel]
             else:
-                with cls._queues_lock:
-                    event_channel = cls._queues[database]['events'][channel]
+                with cls._queues_lock[pid]:
+                    event_channel = cls._queues[pid, database][
+                        'events'][channel]
             event_channel.append(event)
 
         triggered = event.wait(_long_polling_timeout)
@@ -121,9 +124,9 @@
             response = cls.create_response(
                 *cls._messages[database].get_next(channels, last_message))
 
-        with cls._queues_lock:
+        with cls._queues_lock[pid]:
             for channel in channels:
-                events = cls._queues[database]['events'][channel]
+                events = cls._queues[pid, database]['events'][channel]
                 for e in events[:]:
                     if e.is_set():
                         events.remove(e)
@@ -147,6 +150,7 @@
 
         logger.info("listening on channel '%s'", cls._channel)
         conn = db.get_connection()
+        pid = os.getpid()
         try:
             cursor = conn.cursor()
             cursor.execute('LISTEN "%s"' % cls._channel)
@@ -155,7 +159,7 @@
             cls._messages[database] = messages = _MessageQueue(_cache_timeout)
 
             now = time.time()
-            while cls._queues[database]['timeout'] > now:
+            while cls._queues[pid, database]['timeout'] > now:
                 readable, _, _ = select.select([conn], [], [], _select_timeout)
                 if not readable:
                     continue
@@ -170,10 +174,10 @@
                     message = payload['message']
                     messages.append(channel, message)
 
-                    with cls._queues_lock:
-                        events = \
-                            cls._queues[database]['events'][channel].copy()
-                        cls._queues[database]['events'][channel].clear()
+                    with cls._queues_lock[pid]:
+                        events = cls._queues[pid, database][
+                            'events'][channel].copy()
+                        cls._queues[pid, database]['events'][channel].clear()
                     for event in events:
                         event.set()
                 now = time.time()
@@ -181,20 +185,20 @@
             logger.error('bus listener on "%s" crashed', database,
                 exc_info=True)
 
-            with cls._queues_lock:
-                del cls._queues[database]
+            with cls._queues_lock[pid]:
+                del cls._queues[pid, database]
             raise
         finally:
             db.put_connection(conn)
 
-        with cls._queues_lock:
-            if cls._queues[database]['timeout'] <= now:
-                del cls._queues[database]
+        with cls._queues_lock[pid]:
+            if cls._queues[pid, database]['timeout'] <= now:
+                del cls._queues[pid, database]
             else:
                 # A query arrived between the end of the while and here
                 listener = threading.Thread(
                     target=cls._listen, args=(database,), daemon=True)
-                cls._queues[database]['listener'] = listener
+                cls._queues[pid, database]['listener'] = listener
                 listener.start()
 
     @classmethod
diff -r 64ef3c03836c -r ecc681290782 trytond/cache.py
--- a/trytond/cache.py  Fri Nov 27 20:49:09 2020 +0100
+++ b/trytond/cache.py  Fri Nov 27 22:19:56 2020 +0100
@@ -3,6 +3,7 @@
 import datetime as dt
 import json
 import logging
+import os
 import select
 import threading
 from collections import OrderedDict, defaultdict
@@ -102,7 +103,7 @@
     _clean_last = datetime.now()
     _default_lower = Transaction.monotonic_time()
     _listener = {}
-    _listener_lock = threading.Lock()
+    _listener_lock = defaultdict(threading.Lock)
     _table = 'ir_cache'
     _channel = _table
 
@@ -171,9 +172,10 @@
         database = transaction.database
         dbname = database.name
         if not _clear_timeout and database.has_channel():
-            with cls._listener_lock:
-                if dbname not in cls._listener:
-                    cls._listener[dbname] = listener = threading.Thread(
+            pid = os.getpid()
+            with cls._listener_lock[pid]:
+                if (pid, dbname) not in cls._listener:
+                    cls._listener[pid, dbname] = listener = threading.Thread(
                         target=cls._listen, args=(dbname,), daemon=True)
                     listener.start()
             return
@@ -266,8 +268,9 @@
 
     @classmethod
     def drop(cls, dbname):
-        with cls._listener_lock:
-            listener = cls._listener.pop(dbname, None)
+        pid = os.getpid()
+        with cls._listener_lock[pid]:
+            listener = cls._listener.pop((pid, dbname), None)
         if listener:
             database = backend.Database(dbname)
             conn = database.get_connection()
@@ -291,12 +294,14 @@
 
         logger.info("listening on channel '%s' of '%s'", cls._channel, dbname)
         conn = database.get_connection()
+        pid = os.getpid()
+        current_thread = threading.current_thread()
         try:
             cursor = conn.cursor()
             cursor.execute('LISTEN "%s"' % cls._channel)
             conn.commit()
 
-            while cls._listener.get(dbname) == threading.current_thread():
+            while cls._listener.get((pid, dbname)) == current_thread:
                 readable, _, _ = select.select([conn], [], [])
                 if not readable:
                     continue
@@ -316,9 +321,9 @@
             raise
         finally:
             database.put_connection(conn)
-            with cls._listener_lock:
-                if cls._listener.get(dbname) == threading.current_thread():
-                    del cls._listener[dbname]
+            with cls._listener_lock[pid]:
+                if cls._listener.get((pid, dbname)) == current_thread:
+                    del cls._listener[pid, dbname]
 
 
 if config.get('cache', 'class'):
diff -r 64ef3c03836c -r ecc681290782 trytond/tests/test_bus.py
--- a/trytond/tests/test_bus.py Fri Nov 27 20:49:09 2020 +0100
+++ b/trytond/tests/test_bus.py Fri Nov 27 22:19:56 2020 +0100
@@ -1,5 +1,6 @@
 # This file is part of Tryton.  The COPYRIGHT file at the top level of
 # this repository contains the full copyright notices and license terms.
+import os
 import time
 import unittest
 from unittest.mock import patch
@@ -95,10 +96,11 @@
             setattr, bus, '_select_timeout', reset_select_timeout)
 
     def tearDown(self):
-        if DB_NAME in Bus._queues:
-            with Bus._queues_lock:
-                Bus._queues[DB_NAME]['timeout'] = 0
-                listener = Bus._queues[DB_NAME]['listener']
+        pid = os.getpid()
+        if (pid, DB_NAME) in Bus._queues:
+            with Bus._queues_lock[pid]:
+                Bus._queues[pid, DB_NAME]['timeout'] = 0
+                listener = Bus._queues[pid, DB_NAME]['listener']
             listener.join()
         Bus._messages.clear()
 

Reply via email to