changeset 642bbe20ede1 in trytond:5.4
details: https://hg.tryton.org/trytond?cmd=changeset;node=642bbe20ede1
description:
Store listener threads per process
This ensures that if the process is forked, the child process will start
new listeners for the Cache and the Bus.
issue9832
review329601002
(grafted from ecc6812907828d4971a308435c64d32090bd07d4)
diffstat:
trytond/bus.py | 52 +++++++++++++++++++++++++---------------------
trytond/cache.py | 25 +++++++++++++---------
trytond/tests/test_bus.py | 10 +++++---
3 files changed, 49 insertions(+), 38 deletions(-)
diffs (245 lines):
diff -r 73aa96cd9f8a -r 642bbe20ede1 trytond/bus.py
--- a/trytond/bus.py Wed Nov 11 15:52:04 2020 +0100
+++ b/trytond/bus.py Fri Nov 27 22:19:56 2020 +0100
@@ -4,6 +4,7 @@
import collections
import json
import logging
+import os
import select
import threading
import time
@@ -43,7 +44,7 @@
def __init__(self, timeout):
super().__init__()
- self._lock = threading.Lock()
+ self._lock = collections.defaultdict(threading.Lock)
self._timeout = timeout
self._messages = []
@@ -73,7 +74,7 @@
if first_message and not found:
message = first_message
- with self._lock:
+ with self._lock[os.getpid()]:
del self._messages[:to_delete_index]
return message.channel, message.content
@@ -82,20 +83,21 @@
class LongPollingBus:
_channel = 'bus'
- _queues_lock = threading.Lock()
+ _queues_lock = collections.defaultdict(threading.Lock)
_queues = collections.defaultdict(
lambda: {'timeout': None, 'events': collections.defaultdict(list)})
_messages = {}
@classmethod
def subscribe(cls, database, channels, last_message=None):
- with cls._queues_lock:
- start_listener = database not in cls._queues
- cls._queues[database]['timeout'] = time.time() + _db_timeout
+ pid = os.getpid()
+ with cls._queues_lock[pid]:
+ start_listener = (pid, database) not in cls._queues
+ cls._queues[pid, database]['timeout'] = time.time() + _db_timeout
if start_listener:
listener = threading.Thread(
target=cls._listen, args=(database,), daemon=True)
- cls._queues[database]['listener'] = listener
+ cls._queues[pid, database]['listener'] = listener
listener.start()
messages = cls._messages.get(database)
@@ -106,11 +108,12 @@
event = threading.Event()
for channel in channels:
- if channel in cls._queues[database]['events']:
- event_channel = cls._queues[database]['events'][channel]
+ if channel in cls._queues[pid, database]['events']:
+ event_channel = cls._queues[pid, database]['events'][channel]
else:
- with cls._queues_lock:
- event_channel = cls._queues[database]['events'][channel]
+ with cls._queues_lock[pid]:
+ event_channel = cls._queues[pid, database][
+ 'events'][channel]
event_channel.append(event)
triggered = event.wait(_long_polling_timeout)
@@ -120,9 +123,9 @@
response = cls.create_response(
*cls._messages[database].get_next(channels, last_message))
- with cls._queues_lock:
+ with cls._queues_lock[pid]:
for channel in channels:
- events = cls._queues[database]['events'][channel]
+ events = cls._queues[pid, database]['events'][channel]
for e in events[:]:
if e.is_set():
events.remove(e)
@@ -147,6 +150,7 @@
logger.info("listening on channel '%s'", cls._channel)
conn = db.get_connection()
+ pid = os.getpid()
try:
cursor = conn.cursor()
cursor.execute('LISTEN "%s"' % cls._channel)
@@ -155,7 +159,7 @@
cls._messages[database] = messages = _MessageQueue(_cache_timeout)
now = time.time()
- while cls._queues[database]['timeout'] > now:
+ while cls._queues[pid, database]['timeout'] > now:
readable, _, _ = select.select([conn], [], [], _select_timeout)
if not readable:
continue
@@ -170,10 +174,10 @@
message = payload['message']
messages.append(channel, message)
- with cls._queues_lock:
- events = \
- cls._queues[database]['events'][channel].copy()
- cls._queues[database]['events'][channel].clear()
+ with cls._queues_lock[pid]:
+ events = cls._queues[pid, database][
+ 'events'][channel].copy()
+ cls._queues[pid, database]['events'][channel].clear()
for event in events:
event.set()
now = time.time()
@@ -181,20 +185,20 @@
logger.error('bus listener on "%s" crashed', database,
exc_info=True)
- with cls._queues_lock:
- del cls._queues[database]
+ with cls._queues_lock[pid]:
+ del cls._queues[pid, database]
raise
finally:
db.put_connection(conn)
- with cls._queues_lock:
- if cls._queues[database]['timeout'] <= now:
- del cls._queues[database]
+ with cls._queues_lock[pid]:
+ if cls._queues[pid, database]['timeout'] <= now:
+ del cls._queues[pid, database]
else:
# A query arrived between the end of the while and here
listener = threading.Thread(
target=cls._listen, args=(database,), daemon=True)
- cls._queues[database]['listener'] = listener
+ cls._queues[pid, database]['listener'] = listener
listener.start()
@classmethod
diff -r 73aa96cd9f8a -r 642bbe20ede1 trytond/cache.py
--- a/trytond/cache.py Wed Nov 11 15:52:04 2020 +0100
+++ b/trytond/cache.py Fri Nov 27 22:19:56 2020 +0100
@@ -3,6 +3,7 @@
import datetime as dt
import json
import logging
+import os
import select
import threading
from collections import OrderedDict, defaultdict
@@ -99,7 +100,7 @@
_clean_last = datetime.now()
_default_lower = Transaction.monotonic_time()
_listener = {}
- _listener_lock = threading.Lock()
+ _listener_lock = defaultdict(threading.Lock)
_table = 'ir_cache'
_channel = _table
@@ -168,9 +169,10 @@
database = transaction.database
dbname = database.name
if not _clear_timeout and database.has_channel():
- with cls._listener_lock:
- if dbname not in cls._listener:
- cls._listener[dbname] = listener = threading.Thread(
+ pid = os.getpid()
+ with cls._listener_lock[pid]:
+ if (pid, dbname) not in cls._listener:
+ cls._listener[pid, dbname] = listener = threading.Thread(
target=cls._listen, args=(dbname,), daemon=True)
listener.start()
return
@@ -259,8 +261,9 @@
@classmethod
def drop(cls, dbname):
- with cls._listener_lock:
- listener = cls._listener.pop(dbname, None)
+ pid = os.getpid()
+ with cls._listener_lock[pid]:
+ listener = cls._listener.pop((pid, dbname), None)
if listener:
Database = backend.get('Database')
database = Database(dbname)
@@ -286,12 +289,14 @@
logger.info("listening on channel '%s' of '%s'", cls._channel, dbname)
conn = database.get_connection()
+ pid = os.getpid()
+ current_thread = threading.current_thread()
try:
cursor = conn.cursor()
cursor.execute('LISTEN "%s"' % cls._channel)
conn.commit()
- while cls._listener.get(dbname) == threading.current_thread():
+ while cls._listener.get((pid, dbname)) == current_thread:
readable, _, _ = select.select([conn], [], [])
if not readable:
continue
@@ -310,9 +315,9 @@
raise
finally:
database.put_connection(conn)
- with cls._listener_lock:
- if cls._listener.get(dbname) == threading.current_thread():
- del cls._listener[dbname]
+ with cls._listener_lock[pid]:
+ if cls._listener.get((pid, dbname)) == current_thread:
+ del cls._listener[pid, dbname]
if config.get('cache', 'class'):
diff -r 73aa96cd9f8a -r 642bbe20ede1 trytond/tests/test_bus.py
--- a/trytond/tests/test_bus.py Wed Nov 11 15:52:04 2020 +0100
+++ b/trytond/tests/test_bus.py Fri Nov 27 22:19:56 2020 +0100
@@ -1,5 +1,6 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
+import os
import time
import unittest
from unittest.mock import patch
@@ -95,10 +96,11 @@
setattr, bus, '_select_timeout', reset_select_timeout)
def tearDown(self):
- if DB_NAME in Bus._queues:
- with Bus._queues_lock:
- Bus._queues[DB_NAME]['timeout'] = 0
- listener = Bus._queues[DB_NAME]['listener']
+ pid = os.getpid()
+ if (pid, DB_NAME) in Bus._queues:
+ with Bus._queues_lock[pid]:
+ Bus._queues[pid, DB_NAME]['timeout'] = 0
+ listener = Bus._queues[pid, DB_NAME]['listener']
listener.join()
Bus._messages.clear()