changeset ecc681290782 in trytond:default
details: https://hg.tryton.org/trytond?cmd=changeset;node=ecc681290782
description:
Store listener threads per process
This ensures that if the process is forked, the child process will start
new listeners for the Cache and the Bus.
issue9832
review329601002
diffstat:
trytond/bus.py | 52 +++++++++++++++++++++++++---------------------
trytond/cache.py | 25 +++++++++++++---------
trytond/tests/test_bus.py | 10 +++++---
3 files changed, 49 insertions(+), 38 deletions(-)
diffs (245 lines):
diff -r 64ef3c03836c -r ecc681290782 trytond/bus.py
--- a/trytond/bus.py Fri Nov 27 20:49:09 2020 +0100
+++ b/trytond/bus.py Fri Nov 27 22:19:56 2020 +0100
@@ -4,6 +4,7 @@
import collections
import json
import logging
+import os
import select
import threading
import time
@@ -44,7 +45,7 @@
def __init__(self, timeout):
super().__init__()
- self._lock = threading.Lock()
+ self._lock = collections.defaultdict(threading.Lock)
self._timeout = timeout
self._messages = []
@@ -74,7 +75,7 @@
if first_message and not found:
message = first_message
- with self._lock:
+ with self._lock[os.getpid()]:
del self._messages[:to_delete_index]
return message.channel, message.content
@@ -83,20 +84,21 @@
class LongPollingBus:
_channel = 'bus'
- _queues_lock = threading.Lock()
+ _queues_lock = collections.defaultdict(threading.Lock)
_queues = collections.defaultdict(
lambda: {'timeout': None, 'events': collections.defaultdict(list)})
_messages = {}
@classmethod
def subscribe(cls, database, channels, last_message=None):
- with cls._queues_lock:
- start_listener = database not in cls._queues
- cls._queues[database]['timeout'] = time.time() + _db_timeout
+ pid = os.getpid()
+ with cls._queues_lock[pid]:
+ start_listener = (pid, database) not in cls._queues
+ cls._queues[pid, database]['timeout'] = time.time() + _db_timeout
if start_listener:
listener = threading.Thread(
target=cls._listen, args=(database,), daemon=True)
- cls._queues[database]['listener'] = listener
+ cls._queues[pid, database]['listener'] = listener
listener.start()
messages = cls._messages.get(database)
@@ -107,11 +109,12 @@
event = threading.Event()
for channel in channels:
- if channel in cls._queues[database]['events']:
- event_channel = cls._queues[database]['events'][channel]
+ if channel in cls._queues[pid, database]['events']:
+ event_channel = cls._queues[pid, database]['events'][channel]
else:
- with cls._queues_lock:
- event_channel = cls._queues[database]['events'][channel]
+ with cls._queues_lock[pid]:
+ event_channel = cls._queues[pid, database][
+ 'events'][channel]
event_channel.append(event)
triggered = event.wait(_long_polling_timeout)
@@ -121,9 +124,9 @@
response = cls.create_response(
*cls._messages[database].get_next(channels, last_message))
- with cls._queues_lock:
+ with cls._queues_lock[pid]:
for channel in channels:
- events = cls._queues[database]['events'][channel]
+ events = cls._queues[pid, database]['events'][channel]
for e in events[:]:
if e.is_set():
events.remove(e)
@@ -147,6 +150,7 @@
logger.info("listening on channel '%s'", cls._channel)
conn = db.get_connection()
+ pid = os.getpid()
try:
cursor = conn.cursor()
cursor.execute('LISTEN "%s"' % cls._channel)
@@ -155,7 +159,7 @@
cls._messages[database] = messages = _MessageQueue(_cache_timeout)
now = time.time()
- while cls._queues[database]['timeout'] > now:
+ while cls._queues[pid, database]['timeout'] > now:
readable, _, _ = select.select([conn], [], [], _select_timeout)
if not readable:
continue
@@ -170,10 +174,10 @@
message = payload['message']
messages.append(channel, message)
- with cls._queues_lock:
- events = \
- cls._queues[database]['events'][channel].copy()
- cls._queues[database]['events'][channel].clear()
+ with cls._queues_lock[pid]:
+ events = cls._queues[pid, database][
+ 'events'][channel].copy()
+ cls._queues[pid, database]['events'][channel].clear()
for event in events:
event.set()
now = time.time()
@@ -181,20 +185,20 @@
logger.error('bus listener on "%s" crashed', database,
exc_info=True)
- with cls._queues_lock:
- del cls._queues[database]
+ with cls._queues_lock[pid]:
+ del cls._queues[pid, database]
raise
finally:
db.put_connection(conn)
- with cls._queues_lock:
- if cls._queues[database]['timeout'] <= now:
- del cls._queues[database]
+ with cls._queues_lock[pid]:
+ if cls._queues[pid, database]['timeout'] <= now:
+ del cls._queues[pid, database]
else:
# A query arrived between the end of the while and here
listener = threading.Thread(
target=cls._listen, args=(database,), daemon=True)
- cls._queues[database]['listener'] = listener
+ cls._queues[pid, database]['listener'] = listener
listener.start()
@classmethod
diff -r 64ef3c03836c -r ecc681290782 trytond/cache.py
--- a/trytond/cache.py Fri Nov 27 20:49:09 2020 +0100
+++ b/trytond/cache.py Fri Nov 27 22:19:56 2020 +0100
@@ -3,6 +3,7 @@
import datetime as dt
import json
import logging
+import os
import select
import threading
from collections import OrderedDict, defaultdict
@@ -102,7 +103,7 @@
_clean_last = datetime.now()
_default_lower = Transaction.monotonic_time()
_listener = {}
- _listener_lock = threading.Lock()
+ _listener_lock = defaultdict(threading.Lock)
_table = 'ir_cache'
_channel = _table
@@ -171,9 +172,10 @@
database = transaction.database
dbname = database.name
if not _clear_timeout and database.has_channel():
- with cls._listener_lock:
- if dbname not in cls._listener:
- cls._listener[dbname] = listener = threading.Thread(
+ pid = os.getpid()
+ with cls._listener_lock[pid]:
+ if (pid, dbname) not in cls._listener:
+ cls._listener[pid, dbname] = listener = threading.Thread(
target=cls._listen, args=(dbname,), daemon=True)
listener.start()
return
@@ -266,8 +268,9 @@
@classmethod
def drop(cls, dbname):
- with cls._listener_lock:
- listener = cls._listener.pop(dbname, None)
+ pid = os.getpid()
+ with cls._listener_lock[pid]:
+ listener = cls._listener.pop((pid, dbname), None)
if listener:
database = backend.Database(dbname)
conn = database.get_connection()
@@ -291,12 +294,14 @@
logger.info("listening on channel '%s' of '%s'", cls._channel, dbname)
conn = database.get_connection()
+ pid = os.getpid()
+ current_thread = threading.current_thread()
try:
cursor = conn.cursor()
cursor.execute('LISTEN "%s"' % cls._channel)
conn.commit()
- while cls._listener.get(dbname) == threading.current_thread():
+ while cls._listener.get((pid, dbname)) == current_thread:
readable, _, _ = select.select([conn], [], [])
if not readable:
continue
@@ -316,9 +321,9 @@
raise
finally:
database.put_connection(conn)
- with cls._listener_lock:
- if cls._listener.get(dbname) == threading.current_thread():
- del cls._listener[dbname]
+ with cls._listener_lock[pid]:
+ if cls._listener.get((pid, dbname)) == current_thread:
+ del cls._listener[pid, dbname]
if config.get('cache', 'class'):
diff -r 64ef3c03836c -r ecc681290782 trytond/tests/test_bus.py
--- a/trytond/tests/test_bus.py Fri Nov 27 20:49:09 2020 +0100
+++ b/trytond/tests/test_bus.py Fri Nov 27 22:19:56 2020 +0100
@@ -1,5 +1,6 @@
# This file is part of Tryton. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
+import os
import time
import unittest
from unittest.mock import patch
@@ -95,10 +96,11 @@
setattr, bus, '_select_timeout', reset_select_timeout)
def tearDown(self):
- if DB_NAME in Bus._queues:
- with Bus._queues_lock:
- Bus._queues[DB_NAME]['timeout'] = 0
- listener = Bus._queues[DB_NAME]['listener']
+ pid = os.getpid()
+ if (pid, DB_NAME) in Bus._queues:
+ with Bus._queues_lock[pid]:
+ Bus._queues[pid, DB_NAME]['timeout'] = 0
+ listener = Bus._queues[pid, DB_NAME]['listener']
listener.join()
Bus._messages.clear()