This is an automated email from the ASF dual-hosted git repository.

qian pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 285791d60317bb3faf0601124256a3b49bc33c46
Author: yuqian90 <[email protected]>
AuthorDate: Sat May 29 23:00:54 2021 +0800

    Fix Celery executor getting stuck randomly because of reset_signals in multiprocessing (#15989)
    
    Fixes #15938
    
    multiprocessing.Pool is known to often become stuck, which causes
    celery_executor to hang randomly. This happens at least on Debian and
    Ubuntu with Python 3.8.7 and Python 3.8.10. The issue is reproducible by
    running test_send_tasks_to_celery_hang in this PR several times (with the
    db backend set to something other than sqlite, because sqlite disables
    some parallelization).
    
    The issue goes away once we switch to
    concurrent.futures.ProcessPoolExecutor. In Python 3.6 and earlier,
    ProcessPoolExecutor has no initializer argument. Fortunately, one is no
    longer needed: reset_signals can be dropped entirely, because the signal
    handlers now check whether the current process is the parent.
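
    In sketch form, the guard mirrors the _is_parent_process helper added to
    scheduler_job.py below (simplified to SIGTERM only; the exit code and the
    module-level registration are illustrative, standard library only):

        import multiprocessing
        import signal
        import sys

        def _is_parent_process():
            # Processes started by multiprocessing (and therefore by
            # ProcessPoolExecutor workers) get a name other than 'MainProcess'.
            return multiprocessing.current_process().name == 'MainProcess'

        def _exit_gracefully(signum, frame):
            if not _is_parent_process():
                # Only the parent should clean up and exit; children just return.
                return
            sys.exit(0)

        signal.signal(signal.SIGTERM, _exit_gracefully)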
    
    (cherry picked from commit f75dd7ae6e755dad328ba6f3fd462ade194dab25)
---
 airflow/executors/celery_executor.py    | 24 +++++----------
 airflow/jobs/scheduler_job.py           | 16 ++++++++++
 scripts/ci/docker-compose/base.yml      |  2 ++
 tests/executors/test_celery_executor.py | 52 +++++++++++++++++++++++++++++++++
 4 files changed, 78 insertions(+), 16 deletions(-)

diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index bc321c6..553639b 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -30,7 +30,8 @@ import subprocess
 import time
 import traceback
 from collections import OrderedDict
-from multiprocessing import Pool, cpu_count
+from concurrent.futures import ProcessPoolExecutor
+from multiprocessing import cpu_count
 from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
 
 from celery import Celery, Task, states as celery_states
@@ -318,18 +319,9 @@ class CeleryExecutor(BaseExecutor):
         chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
         num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
 
-        def reset_signals():
-            # Since we are run from inside the SchedulerJob, we don't to
-            # inherit the signal handlers that we registered there.
-            import signal
-
-            signal.signal(signal.SIGINT, signal.SIG_DFL)
-            signal.signal(signal.SIGTERM, signal.SIG_DFL)
-            signal.signal(signal.SIGUSR2, signal.SIG_DFL)
-
-        with Pool(processes=num_processes, initializer=reset_signals) as send_pool:
-            key_and_async_results = send_pool.map(
-                send_task_to_executor, task_tuples_to_send, chunksize=chunksize
+        with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
+            key_and_async_results = list(
+                send_pool.map(send_task_to_executor, task_tuples_to_send, chunksize=chunksize)
             )
         return key_and_async_results
 
@@ -592,11 +584,11 @@ class BulkStateFetcher(LoggingMixin):
     def _get_many_using_multiprocessing(self, async_results) -> Mapping[str, EventBufferValueType]:
         num_process = min(len(async_results), self._sync_parallelism)
 
-        with Pool(processes=num_process) as sync_pool:
+        with ProcessPoolExecutor(max_workers=num_process) as sync_pool:
             chunksize = max(1, math.floor(math.ceil(1.0 * len(async_results) / self._sync_parallelism)))
 
-            task_id_to_states_and_info = sync_pool.map(
-                fetch_celery_task_state, async_results, chunksize=chunksize
+            task_id_to_states_and_info = list(
+                sync_pool.map(fetch_celery_task_state, async_results, chunksize=chunksize)
             )
 
             states_and_info_by_task_id: MutableMapping[str, EventBufferValueType] = {}
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index e86a6e7..cece87e 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -670,6 +670,14 @@ class DagFileProcessor(LoggingMixin):
         return len(dagbag.dags), len(dagbag.import_errors)
 
 
+def _is_parent_process():
+    """
+    Returns True if the current process is the parent process. False if the current process is a child
+    process started by multiprocessing.
+    """
+    return multiprocessing.current_process().name == 'MainProcess'
+
+
 class SchedulerJob(BaseJob):  # pylint: disable=too-many-instance-attributes
     """
     This SchedulerJob runs for a specific time interval and schedules the jobs
@@ -745,12 +753,20 @@ class SchedulerJob(BaseJob):  # pylint: disable=too-many-instance-attributes
 
     def _exit_gracefully(self, signum, frame) -> None:  # pylint: disable=unused-argument
         """Helper method to clean up processor_agent to avoid leaving orphan 
processes."""
+        if not _is_parent_process():
+            # Only the parent process should perform the cleanup.
+            return
+
         self.log.info("Exiting gracefully upon receiving signal %s", signum)
         if self.processor_agent:
             self.processor_agent.end()
         sys.exit(os.EX_OK)
 
     def _debug_dump(self, signum, frame):  # pylint: disable=unused-argument
+        if not _is_parent_process():
+            # Only the parent process should perform the debug dump.
+            return
+
         try:
             sig_name = signal.Signals(signum).name  # pylint: disable=no-member
         except Exception:  # pylint: disable=broad-except
diff --git a/scripts/ci/docker-compose/base.yml b/scripts/ci/docker-compose/base.yml
index eab6425..6b1cb4e 100644
--- a/scripts/ci/docker-compose/base.yml
+++ b/scripts/ci/docker-compose/base.yml
@@ -34,6 +34,8 @@ services:
     ports:
       - "${WEBSERVER_HOST_PORT}:8080"
       - "${FLOWER_HOST_PORT}:5555"
+    cap_add:
+      - SYS_PTRACE
 volumes:
   sqlite-db-volume:
   postgres-db-volume:
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index f454c5a..19c8a0d 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -18,6 +18,7 @@
 import contextlib
 import json
 import os
+import signal
 import sys
 import unittest
 from datetime import datetime, timedelta
@@ -484,3 +485,54 @@ class TestBulkStateFetcher(unittest.TestCase):
         assert [
             'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
         ] == cm.output
+
+
+class MockTask:
+    """
+    A picklable object used to mock tasks sent to Celery. Can't use the mock library
+    here because it's not picklable.
+    """
+
+    def apply_async(self, *args, **kwargs):
+        return 1
+
+
+def _exit_gracefully(signum, _):
+    print(f"{os.getpid()} Exiting gracefully upon receiving signal {signum}")
+    sys.exit(signum)
+
+
[email protected]
+def register_signals():
+    """
+    Register the same signals as scheduler does to test celery_executor to make sure it does not
+    hang.
+    """
+    orig_sigint = orig_sigterm = orig_sigusr2 = signal.SIG_DFL
+
+    orig_sigint = signal.signal(signal.SIGINT, _exit_gracefully)
+    orig_sigterm = signal.signal(signal.SIGTERM, _exit_gracefully)
+    orig_sigusr2 = signal.signal(signal.SIGUSR2, _exit_gracefully)
+
+    yield
+
+    # Restore original signal handlers after test
+    signal.signal(signal.SIGINT, orig_sigint)
+    signal.signal(signal.SIGTERM, orig_sigterm)
+    signal.signal(signal.SIGUSR2, orig_sigusr2)
+
+
+def test_send_tasks_to_celery_hang(register_signals):  # pylint: disable=unused-argument
+    """
+    Test that celery_executor does not hang after many runs.
+    """
+    executor = celery_executor.CeleryExecutor()
+
+    task = MockTask()
+    task_tuples_to_send = [(None, None, None, None, task) for _ in range(26)]
+
+    for _ in range(500):
+        # This loop can hang on Linux if celery_executor does something wrong with
+        # multiprocessing.
+        results = executor._send_tasks_to_celery(task_tuples_to_send)
+        assert results == [(None, None, 1) for _ in task_tuples_to_send]
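
A note on the list(...) wrapping above, since it is easy to miss: unlike
multiprocessing.Pool.map, concurrent.futures.Executor.map returns a lazy
iterator, so the patch materializes the results with list() to keep handing
callers a concrete list. A minimal, hypothetical illustration (not part of
the patch):

    from concurrent.futures import ProcessPoolExecutor

    def square(x):
        return x * x

    if __name__ == '__main__':
        with ProcessPoolExecutor(max_workers=2) as pool:
            # Executor.map yields results lazily; list() forces evaluation,
            # matching what Pool.map used to return directly.
            results = list(pool.map(square, range(8), chunksize=2))
        print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]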
