This is an automated email from the ASF dual-hosted git repository.

kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 5f3774a  [AIRFLOW-6921] Fetch celery states in bulk (#7542)
5f3774a is described below

commit 5f3774aebdce8fe7bdcc478ea0b05d95c37c9440
Author: Kamil Breguła <[email protected]>
AuthorDate: Mon May 11 10:44:14 2020 +0200

    [AIRFLOW-6921] Fetch celery states in bulk (#7542)
---
 airflow/executors/celery_executor.py    | 225 +++++++++++++++++++-------------
 tests/executors/test_celery_executor.py | 176 ++++++++++++++++++-------
 2 files changed, 268 insertions(+), 133 deletions(-)

diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index 40f173b..98fff4b 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -28,16 +28,19 @@ import subprocess
 import time
 import traceback
 from multiprocessing import Pool, cpu_count
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
 
 from celery import Celery, Task, states as celery_states
+from celery.backends.base import BaseKeyValueStoreBackend
+from celery.backends.database import DatabaseBackend, Task as TaskDb, session_cleanup
 from celery.result import AsyncResult
 
 from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import BaseExecutor, CommandType
-from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKeyType, TaskInstanceStateType
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKeyType
+from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.timeout import timeout
 
 log = logging.getLogger(__name__)
@@ -93,30 +96,6 @@ class ExceptionWithTraceback:
         self.traceback = exception_traceback
 
 
-def fetch_celery_task_state(celery_task: Tuple[TaskInstanceKeyType, AsyncResult]) \
-        -> Union[TaskInstanceStateType, ExceptionWithTraceback]:
-    """
-    Fetch and return the state of the given celery task. The scope of this function is
-    global so that it can be called by subprocesses in the pool.
-
-    :param celery_task: a tuple of the Celery task key and the async Celery object used
-        to fetch the task's state
-    :type celery_task: tuple(str, celery.result.AsyncResult)
-    :return: a tuple of the Celery task key and the Celery state of the task
-    :rtype: tuple[str, str]
-    """
-
-    try:
-        with timeout(seconds=OPERATION_TIMEOUT):
-            # Accessing state property of celery task will make actual network request
-            # to get the current state of the task.
-            return celery_task[0], celery_task[1].state
-    except Exception as e:  # pylint: disable=broad-except
-        exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0],
-                                                              traceback.format_exc())
-        return ExceptionWithTraceback(e, exception_traceback)
-
-
 # Task instance that is sent over Celery queues
 # TaskInstanceKeyType, SimpleTaskInstance, Command, queue_name, CallableTask
 TaskInstanceInCelery = Tuple[TaskInstanceKeyType, SimpleTaskInstance, CommandType, Optional[str], Task]
@@ -149,15 +128,13 @@ class CeleryExecutor(BaseExecutor):
     def __init__(self):
         super().__init__()
 
-        # Celery doesn't support querying the state of multiple tasks in parallel
-        # (which can become a bottleneck on bigger clusters) so we use
-        # a multiprocessing pool to speed this up.
+        # Celery doesn't support bulk sending the tasks (which can become a bottleneck on bigger clusters)
+        # so we use a multiprocessing pool to speed this up.
         # How many worker processes are created for checking celery task state.
         self._sync_parallelism = conf.getint('celery', 'SYNC_PARALLELISM')
         if self._sync_parallelism == 0:
             self._sync_parallelism = max(1, cpu_count() - 1)
-
-        self._sync_pool = None
+        self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism)
         self.tasks = {}
         self.last_state = {}
 
@@ -177,15 +154,6 @@ class CeleryExecutor(BaseExecutor):
         return max(1,
                    int(math.ceil(1.0 * to_send_count / self._sync_parallelism)))
 
-    def _num_tasks_per_fetch_process(self) -> int:
-        """
-        How many Celery tasks should be sent to each worker process.
-
-        :return: Number of tasks that should be used per process
-        :rtype: int
-        """
-        return max(1, int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism)))
-
     def trigger_tasks(self, open_slots: int) -> None:
         """
         Overwrite trigger_tasks function from BaseExecutor
@@ -201,7 +169,6 @@ class CeleryExecutor(BaseExecutor):
             key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
             task_tuples_to_send.append((key, simple_ti, command, queue, execute_command))
 
-        cached_celery_backend = None
         if task_tuples_to_send:
             tasks = [t[4] for t in task_tuples_to_send]
 
@@ -209,20 +176,7 @@ class CeleryExecutor(BaseExecutor):
             # for all tasks.
             cached_celery_backend = tasks[0].backend
 
-        if task_tuples_to_send:
-            # Use chunks instead of a work queue to reduce context switching
-            # since tasks are roughly uniform in size
-            chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
-            num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
-
-            send_pool = Pool(processes=num_processes)
-            key_and_async_results = send_pool.map(
-                send_task_to_executor,
-                task_tuples_to_send,
-                chunksize=chunksize)
-
-            send_pool.close()
-            send_pool.join()
+            key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send)
             self.log.debug('Sent all tasks.')
 
             for key, command, result in key_and_async_results:
@@ -239,46 +193,36 @@ class CeleryExecutor(BaseExecutor):
                     self.tasks[key] = result
                     self.last_state[key] = celery_states.PENDING
 
+    def _send_tasks_to_celery(self, task_tuples_to_send):
+        """Send tasks from a process pool and return the ``(key, command, result)`` tuples."""
+        # Use chunks instead of a work queue to reduce context switching
+        # since tasks are roughly uniform in size
+        chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
+        num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
+        with Pool(processes=num_processes) as send_pool:
+            key_and_async_results = send_pool.map(
+                send_task_to_executor,
+                task_tuples_to_send,
+                chunksize=chunksize)
+        return key_and_async_results
+
     def sync(self) -> None:
-        num_processes = min(len(self.tasks), self._sync_parallelism)
-        if num_processes == 0:
+        if not self.tasks:
             self.log.debug("No task to query celery, skipping sync")
             return
+        self.update_all_task_states()
 
-        self.log.debug("Inquiring about %s celery task(s) using %s processes",
-                       len(self.tasks), num_processes)
-
-        # Recreate the process pool each sync in case processes in the pool die
-        self._sync_pool = Pool(processes=num_processes)
+    def update_all_task_states(self) -> None:
+        """Updates states of the tasks."""
 
-        # Use chunks instead of a work queue to reduce context switching since tasks are
-        # roughly uniform in size
-        chunksize = self._num_tasks_per_fetch_process()
+        self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
+        states_by_celery_task_id = self.bulk_state_fetcher.get_many(self.tasks.values())
 
-        self.log.debug("Waiting for inquiries to complete...")
-        task_keys_to_states = self._sync_pool.map(
-            fetch_celery_task_state,
-            self.tasks.items(),
-            chunksize=chunksize)
-        self._sync_pool.close()
-        self._sync_pool.join()
         self.log.debug("Inquiries completed.")
-
-        self.update_task_states(task_keys_to_states)
-
-    def update_task_states(self,
-                           task_keys_to_states: List[Union[TaskInstanceStateType,
-                                                           ExceptionWithTraceback]]) -> None:
-        """Updates states of the tasks."""
-        for key_and_state in task_keys_to_states:
-            if isinstance(key_and_state, ExceptionWithTraceback):
-                self.log.error(  # pylint: disable=logging-not-lazy
-                    CELERY_FETCH_ERR_MSG_HEADER + ", ignoring it:%s\n%s\n",
-                    repr(key_and_state.exception), key_and_state.traceback
-                )
-                continue
-            key, state = key_and_state
-            self.update_task_state(key, state)
+        for key, async_result in list(self.tasks.items()):
+            state_by_task_id = states_by_celery_task_id.get(async_result.task_id)
+            if state_by_task_id:
+                self.update_task_state(key, state_by_task_id)
 
     def update_task_state(self, key: TaskInstanceKeyType, state: str) -> None:
         """Updates state of a single task."""
@@ -319,3 +262,128 @@ class CeleryExecutor(BaseExecutor):
 
     def terminate(self):
         pass
+
+
+def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback]]:
+    """
+    Fetch and return the state of the given celery task. The scope of this function is
+    global so that it can be called by subprocesses in the pool.
+
+    :param async_result: the Celery async result object used to fetch the task's state
+    :type async_result: celery.result.AsyncResult
+    :return: a tuple of the Celery task ID and either the Celery state of the task
+        or an ExceptionWithTraceback if fetching the state failed
+    :rtype: tuple[str, str | ExceptionWithTraceback]
+    """
+
+    try:
+        with timeout(seconds=OPERATION_TIMEOUT):
+            # Accessing state property of celery task will make actual network request
+            # to get the current state of the task
+            return async_result.task_id, async_result.state
+    except Exception as e:  # pylint: disable=broad-except
+        exception_traceback = f"Celery Task ID: {async_result}\n{traceback.format_exc()}"
+        return async_result.task_id, ExceptionWithTraceback(e, exception_traceback)
+
+
+def _tasks_list_to_task_ids(async_tasks) -> Set[str]:
+    """Return the set of Celery task IDs of the given async results."""
+    return {a.task_id for a in async_tasks}
+
+
+class BulkStateFetcher(LoggingMixin):
+    """
+    Gets status for many Celery tasks using the best method available
+
+    If BaseKeyValueStoreBackend is used as result backend, the mget method is used.
+    If DatabaseBackend is used as result backend, the SELECT ... WHERE task_id IN (...) query is used
+    Otherwise, multiprocessing.Pool will be used. Each task status will be downloaded individually.
+
+    :param sync_parralelism: number of worker processes used by the multiprocessing fallback;
+        ``None`` or ``0`` means ``max(1, cpu_count() - 1)``, the same default CeleryExecutor uses.
+    """
+    def __init__(self, sync_parralelism=None):
+        super().__init__()
+        # A falsy value would crash the multiprocessing fallback (min() against None,
+        # division by None/zero), so substitute the executor's default instead.
+        if not sync_parralelism:
+            sync_parralelism = max(1, cpu_count() - 1)
+        self._sync_parallelism = sync_parralelism
+
+    def get_many(self, async_results) -> Mapping[str, str]:
+        """
+        Gets status for many Celery tasks using the best method available.
+
+        :param async_results: iterable of Celery async results
+        :return: mapping of Celery task ID to Celery state
+        """
+        # Materialize once: callers may pass a dict view or a one-shot iterable,
+        # and the strategies below need both repeated iteration and len().
+        async_results = list(async_results)
+        if not async_results:
+            return {}
+        if isinstance(app.backend, BaseKeyValueStoreBackend):
+            result = self._get_many_from_kv_backend(async_results)
+        elif isinstance(app.backend, DatabaseBackend):
+            result = self._get_many_from_db_backend(async_results)
+        else:
+            result = self._get_many_using_multiprocessing(async_results)
+        self.log.debug("Fetched %d states for %d task", len(result), len(async_results))
+        return result
+
+    def _get_many_from_kv_backend(self, async_tasks) -> Mapping[str, str]:
+        """Fetch all states with a single ``mget`` call on the key-value result backend."""
+        task_ids = _tasks_list_to_task_ids(async_tasks)
+        keys = [app.backend.get_key_for_task(k) for k in task_ids]
+        values = app.backend.mget(keys)
+        task_results = [app.backend.decode_result(v) for v in values if v]
+        task_results_by_task_id = {task_result["task_id"]: task_result for task_result in task_results}
+
+        return self._prepare_state_by_task_dict(task_ids, task_results_by_task_id)
+
+    def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, str]:
+        """Fetch all states with a single ``SELECT ... WHERE task_id IN (...)`` query."""
+        task_ids = _tasks_list_to_task_ids(async_tasks)
+        session = app.backend.ResultSession()
+        with session_cleanup(session):
+            tasks = session.query(TaskDb).filter(TaskDb.task_id.in_(task_ids)).all()
+
+        task_results = [app.backend.meta_from_decoded(task.to_dict()) for task in tasks]
+        task_results_by_task_id = {task_result["task_id"]: task_result for task_result in task_results}
+        return self._prepare_state_by_task_dict(task_ids, task_results_by_task_id)
+
+    @staticmethod
+    def _prepare_state_by_task_dict(task_ids, task_results_by_task_id) -> Mapping[str, str]:
+        """Map every requested task ID to its stored status, or PENDING if no result is stored."""
+        states: MutableMapping[str, str] = {}
+        for task_id in task_ids:
+            task_result = task_results_by_task_id.get(task_id)
+            if task_result:
+                state = task_result["status"]
+            else:
+                state = celery_states.PENDING
+            states[task_id] = state
+        return states
+
+    def _get_many_using_multiprocessing(self, async_results) -> Mapping[str, str]:
+        """Fetch each state individually in a process pool; failed fetches are logged and omitted."""
+        num_process = min(len(async_results), self._sync_parallelism)
+
+        with Pool(processes=num_process) as sync_pool:
+            chunksize = max(1, math.ceil(len(async_results) / self._sync_parallelism))
+
+            task_id_to_states_or_exception = sync_pool.map(
+                fetch_celery_task_state,
+                async_results,
+                chunksize=chunksize)
+
+            states_by_task_id: MutableMapping[str, str] = {}
+            for task_id, state_or_exception in task_id_to_states_or_exception:
+                if isinstance(state_or_exception, ExceptionWithTraceback):
+                    self.log.error(  # pylint: disable=logging-not-lazy
+                        CELERY_FETCH_ERR_MSG_HEADER + ":%s\n%s\n",
+                        state_or_exception.exception, state_or_exception.traceback
+                    )
+                else:
+                    states_by_task_id[task_id] = state_or_exception
+        return states_by_task_id
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index f94ff2c..2890544 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -17,6 +17,7 @@
 # under the License.
 import contextlib
 import datetime
+import json
 import os
 import sys
 import unittest
@@ -28,12 +29,15 @@ from unittest import mock
 import celery.contrib.testing.tasks  # noqa: F401 pylint: disable=unused-import
 import pytest
 from celery import Celery, states as celery_states
+from celery.backends.base import BaseBackend, BaseKeyValueStoreBackend
+from celery.backends.database import DatabaseBackend
 from celery.contrib.testing.worker import start_worker
 from kombu.asynchronous import set_event_loop
 from parameterized import parameterized
 
 from airflow.configuration import conf
 from airflow.executors import celery_executor
+from airflow.executors.celery_executor import BulkStateFetcher
 from airflow.models import TaskInstance
 from airflow.models.dag import DAG
 from airflow.models.taskinstance import SimpleTaskInstance
@@ -50,33 +54,46 @@ def _prepare_test_bodies():
     return [(conf.get('celery', 'BROKER_URL'))]
 
 
-class TestCeleryExecutor(unittest.TestCase):
+class FakeCeleryResult:
+    """Stand-in for a Celery AsyncResult whose state lookup always raises."""
+    @property
+    def state(self):
+        raise Exception()
+
+    @property
+    def task_id(self):
+        return "task_id"
+
+
+@contextlib.contextmanager
+def _prepare_app(broker_url=None, execute=None):
+    """Patch the executor module's Celery app and execute_command with a test app and yield it."""
+    broker_url = broker_url or conf.get('celery', 'BROKER_URL')
+    execute = execute or celery_executor.execute_command.__wrapped__
+
+    test_config = dict(celery_executor.celery_configuration)
+    test_config.update({'broker_url': broker_url})
+    test_app = Celery(broker_url, config_source=test_config)
+    test_execute = test_app.task(execute)
+    patch_app = mock.patch('airflow.executors.celery_executor.app', test_app)
+    patch_execute = mock.patch('airflow.executors.celery_executor.execute_command', test_execute)
 
-    @contextlib.contextmanager
-    def _prepare_app(self, broker_url=None, execute=None):
-        broker_url = broker_url or conf.get('celery', 'BROKER_URL')
-        execute = execute or celery_executor.execute_command.__wrapped__
-
-        test_config = dict(celery_executor.celery_configuration)
-        test_config.update({'broker_url': broker_url})
-        test_app = Celery(broker_url, config_source=test_config)
-        test_execute = test_app.task(execute)
-        patch_app = mock.patch('airflow.executors.celery_executor.app', test_app)
-        patch_execute = mock.patch('airflow.executors.celery_executor.execute_command', test_execute)
-
-        with patch_app, patch_execute:
-            try:
-                yield test_app
-            finally:
-                # Clear event loop to tear down each celery instance
-                set_event_loop(None)
+    with patch_app, patch_execute:
+        try:
+            yield test_app
+        finally:
+            # Clear event loop to tear down each celery instance
+            set_event_loop(None)
+
+
+class TestCeleryExecutor(unittest.TestCase):
 
     @parameterized.expand(_prepare_test_bodies())
     @pytest.mark.integration("redis")
     @pytest.mark.integration("rabbitmq")
     @pytest.mark.backend("mysql", "postgres")
     def test_celery_integration(self, broker_url):
-        with self._prepare_app(broker_url) as app:
+        with _prepare_app(broker_url) as app:
             executor = celery_executor.CeleryExecutor()
             executor.start()
 
@@ -98,14 +112,11 @@ class TestCeleryExecutor(unittest.TestCase):
                 chunksize = executor._num_tasks_per_send_process(len(task_tuples_to_send))
                 num_processes = min(len(task_tuples_to_send), executor._sync_parallelism)
 
-                send_pool = Pool(processes=num_processes)
-                key_and_async_results = send_pool.map(
-                    celery_executor.send_task_to_executor,
-                    task_tuples_to_send,
-                    chunksize=chunksize)
-
-                send_pool.close()
-                send_pool.join()
+                with Pool(processes=num_processes) as send_pool:
+                    key_and_async_results = send_pool.map(
+                        celery_executor.send_task_to_executor,
+                        task_tuples_to_send,
+                        chunksize=chunksize)
 
                 for task_instance_key, _, result in key_and_async_results:
                     # Only pops when enqueued successfully, otherwise keep it
@@ -136,7 +147,7 @@ class TestCeleryExecutor(unittest.TestCase):
         def fake_execute_command():
             pass
 
-        with self._prepare_app(execute=fake_execute_command):
+        with _prepare_app(execute=fake_execute_command):
             # fake_execute_command takes no arguments while execute_command takes 1,
             # which will cause TypeError when calling task.apply_async()
             executor = celery_executor.CeleryExecutor()
@@ -155,25 +166,16 @@ class TestCeleryExecutor(unittest.TestCase):
         self.assertEqual(executor.queued_tasks[key], value_tuple)
 
     def test_exception_propagation(self):
-        with self._prepare_app() as app:
-            @app.task
-            def fake_celery_task():
-                return {}
 
-            mock_log = mock.MagicMock()
+        with _prepare_app(), self.assertLogs(celery_executor.log) as cm:
             executor = celery_executor.CeleryExecutor()
-            executor._log = mock_log
-
-            executor.tasks = {'key': fake_celery_task()}
-            executor.sync()
+            executor.tasks = {
+                'key': FakeCeleryResult()
+            }
+            executor.bulk_state_fetcher._get_many_using_multiprocessing(executor.tasks.values())
 
-        assert mock_log.error.call_count == 1
-        args, kwargs = mock_log.error.call_args_list[0]
-        # Result of queuing is not a celery task but a dict,
-        # and it should raise AttributeError and then get propagated
-        # to the error log.
-        self.assertIn(celery_executor.CELERY_FETCH_ERR_MSG_HEADER, args[0])
-        self.assertIn('AttributeError', args[1])
+        self.assertTrue(any(celery_executor.CELERY_FETCH_ERR_MSG_HEADER in line for line in cm.output))
+        self.assertTrue(any("Exception" in line for line in cm.output))
 
     @mock.patch('airflow.executors.celery_executor.CeleryExecutor.sync')
     @mock.patch('airflow.executors.celery_executor.CeleryExecutor.trigger_tasks')
@@ -189,3 +191,91 @@ class TestCeleryExecutor(unittest.TestCase):
 
 def test_operation_timeout_config():
     assert celery_executor.OPERATION_TIMEOUT == 2
+
+
+class ClassWithCustomAttributes:
+    """Class for testing purpose: allows to create objects with custom attributes in one single statement."""
+
+    def __init__(self, **kwargs):
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+    def __str__(self):
+        return "{}({})".format(ClassWithCustomAttributes.__name__, str(self.__dict__))
+
+    def __repr__(self):
+        return self.__str__()
+
+    def __eq__(self, other):
+        return self.__dict__ == other.__dict__
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class TestBulkStateFetcher(unittest.TestCase):
+
+    @mock.patch("celery.backends.base.BaseKeyValueStoreBackend.mget", return_value=[
+        json.dumps({"status": "SUCCESS", "task_id": "123"})
+    ])
+    @pytest.mark.integration("redis")
+    @pytest.mark.integration("rabbitmq")
+    @pytest.mark.backend("mysql", "postgres")
+    def test_should_support_kv_backend(self, mock_mget):
+        with _prepare_app():
+            mock_backend = BaseKeyValueStoreBackend(app=celery_executor.app)
+            with mock.patch.object(celery_executor.app, 'backend', mock_backend):
+                fetcher = BulkStateFetcher()
+                result = fetcher.get_many([
+                    mock.MagicMock(task_id="123"),
+                    mock.MagicMock(task_id="456"),
+                ])
+
+        # Assert called - ignore order
+        mget_args, _ = mock_mget.call_args
+        self.assertEqual(set(mget_args[0]), {b'celery-task-meta-456', b'celery-task-meta-123'})
+        mock_mget.assert_called_once_with(mock.ANY)
+
+        self.assertEqual(result, {'123': 'SUCCESS', '456': "PENDING"})
+
+    @mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
+    @pytest.mark.integration("redis")
+    @pytest.mark.integration("rabbitmq")
+    @pytest.mark.backend("mysql", "postgres")
+    def test_should_support_db_backend(self, mock_session):
+        with _prepare_app():
+            mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://")
+
+            with mock.patch.object(celery_executor.app, 'backend', mock_backend):
+                result_session = mock_backend.ResultSession.return_value  # pylint: disable=no-member
+                result_session.query.return_value.filter.return_value.all.return_value = [
+                    mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}})
+                ]
+
+                # Fetch while the DatabaseBackend is still installed on the app; outside this block
+                # the fetcher would inspect the real app's backend and the SQL path would go untested.
+                fetcher = BulkStateFetcher()
+                result = fetcher.get_many([
+                    mock.MagicMock(task_id="123"),
+                    mock.MagicMock(task_id="456"),
+                ])
+
+        self.assertEqual(result, {'123': 'SUCCESS', '456': "PENDING"})
+
+    @pytest.mark.integration("redis")
+    @pytest.mark.integration("rabbitmq")
+    @pytest.mark.backend("mysql", "postgres")
+    def test_should_support_base_backend(self):
+        with _prepare_app():
+            # ``spec`` restricts the mock to BaseBackend; ``autospec`` is not a MagicMock constructor
+            # argument and would only set a plain attribute, leaving the mock unspecced.
+            mock_backend = mock.MagicMock(spec=BaseBackend)
+
+            with mock.patch.object(celery_executor.app, 'backend', mock_backend):
+                fetcher = BulkStateFetcher(1)
+                result = fetcher.get_many([
+                    ClassWithCustomAttributes(task_id="123", state='SUCCESS'),
+                    ClassWithCustomAttributes(task_id="456", state="PENDING"),
+                ])
+
+        self.assertEqual(result, {'123': 'SUCCESS', '456': "PENDING"})

Reply via email to