This is an automated email from the ASF dual-hosted git repository.
kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 5f3774a [AIRFLOW-6921] Fetch celery states in bulk (#7542)
5f3774a is described below
commit 5f3774aebdce8fe7bdcc478ea0b05d95c37c9440
Author: Kamil Breguła <[email protected]>
AuthorDate: Mon May 11 10:44:14 2020 +0200
[AIRFLOW-6921] Fetch celery states in bulk (#7542)
---
airflow/executors/celery_executor.py | 225 +++++++++++++++++++-------------
tests/executors/test_celery_executor.py | 176 ++++++++++++++++++-------
2 files changed, 268 insertions(+), 133 deletions(-)
diff --git a/airflow/executors/celery_executor.py
b/airflow/executors/celery_executor.py
index 40f173b..98fff4b 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -28,16 +28,19 @@ import subprocess
import time
import traceback
from multiprocessing import Pool, cpu_count
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Mapping, MutableMapping, Optional, Set, Tuple,
Union
from celery import Celery, Task, states as celery_states
+from celery.backends.base import BaseKeyValueStoreBackend
+from celery.backends.database import DatabaseBackend, Task as TaskDb,
session_cleanup
from celery.result import AsyncResult
from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor, CommandType
-from airflow.models.taskinstance import SimpleTaskInstance,
TaskInstanceKeyType, TaskInstanceStateType
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKeyType
+from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.timeout import timeout
log = logging.getLogger(__name__)
@@ -93,30 +96,6 @@ class ExceptionWithTraceback:
self.traceback = exception_traceback
-def fetch_celery_task_state(celery_task: Tuple[TaskInstanceKeyType,
AsyncResult]) \
- -> Union[TaskInstanceStateType, ExceptionWithTraceback]:
- """
- Fetch and return the state of the given celery task. The scope of this
function is
- global so that it can be called by subprocesses in the pool.
-
- :param celery_task: a tuple of the Celery task key and the async Celery
object used
- to fetch the task's state
- :type celery_task: tuple(str, celery.result.AsyncResult)
- :return: a tuple of the Celery task key and the Celery state of the task
- :rtype: tuple[str, str]
- """
-
- try:
- with timeout(seconds=OPERATION_TIMEOUT):
- # Accessing state property of celery task will make actual network
request
- # to get the current state of the task.
- return celery_task[0], celery_task[1].state
- except Exception as e: # pylint: disable=broad-except
- exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0],
-
traceback.format_exc())
- return ExceptionWithTraceback(e, exception_traceback)
-
-
# Task instance that is sent over Celery queues
# TaskInstanceKeyType, SimpleTaskInstance, Command, queue_name, CallableTask
TaskInstanceInCelery = Tuple[TaskInstanceKeyType, SimpleTaskInstance,
CommandType, Optional[str], Task]
@@ -149,15 +128,13 @@ class CeleryExecutor(BaseExecutor):
def __init__(self):
super().__init__()
- # Celery doesn't support querying the state of multiple tasks in
parallel
- # (which can become a bottleneck on bigger clusters) so we use
- # a multiprocessing pool to speed this up.
+ # Celery doesn't support bulk sending the tasks (which can become a
bottleneck on bigger clusters)
+ # so we use a multiprocessing pool to speed this up.
# How many worker processes are created for checking celery task state.
self._sync_parallelism = conf.getint('celery', 'SYNC_PARALLELISM')
if self._sync_parallelism == 0:
self._sync_parallelism = max(1, cpu_count() - 1)
-
- self._sync_pool = None
+ self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism)
self.tasks = {}
self.last_state = {}
@@ -177,15 +154,6 @@ class CeleryExecutor(BaseExecutor):
return max(1,
int(math.ceil(1.0 * to_send_count /
self._sync_parallelism)))
- def _num_tasks_per_fetch_process(self) -> int:
- """
- How many Celery tasks should be sent to each worker process.
-
- :return: Number of tasks that should be used per process
- :rtype: int
- """
- return max(1, int(math.ceil(1.0 * len(self.tasks) /
self._sync_parallelism)))
-
def trigger_tasks(self, open_slots: int) -> None:
"""
Overwrite trigger_tasks function from BaseExecutor
@@ -201,7 +169,6 @@ class CeleryExecutor(BaseExecutor):
key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
task_tuples_to_send.append((key, simple_ti, command, queue,
execute_command))
- cached_celery_backend = None
if task_tuples_to_send:
tasks = [t[4] for t in task_tuples_to_send]
@@ -209,20 +176,7 @@ class CeleryExecutor(BaseExecutor):
# for all tasks.
cached_celery_backend = tasks[0].backend
- if task_tuples_to_send:
- # Use chunks instead of a work queue to reduce context switching
- # since tasks are roughly uniform in size
- chunksize =
self._num_tasks_per_send_process(len(task_tuples_to_send))
- num_processes = min(len(task_tuples_to_send),
self._sync_parallelism)
-
- send_pool = Pool(processes=num_processes)
- key_and_async_results = send_pool.map(
- send_task_to_executor,
- task_tuples_to_send,
- chunksize=chunksize)
-
- send_pool.close()
- send_pool.join()
+ key_and_async_results =
self._send_tasks_to_celery(task_tuples_to_send)
self.log.debug('Sent all tasks.')
for key, command, result in key_and_async_results:
@@ -239,46 +193,35 @@ class CeleryExecutor(BaseExecutor):
self.tasks[key] = result
self.last_state[key] = celery_states.PENDING
+ def _send_tasks_to_celery(self, task_tuples_to_send):
+ # Use chunks instead of a work queue to reduce context switching
+ # since tasks are roughly uniform in size
+ chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
+ num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
+ with Pool(processes=num_processes) as send_pool:
+ key_and_async_results = send_pool.map(
+ send_task_to_executor,
+ task_tuples_to_send,
+ chunksize=chunksize)
+ return key_and_async_results
+
def sync(self) -> None:
- num_processes = min(len(self.tasks), self._sync_parallelism)
- if num_processes == 0:
+ if not self.tasks:
self.log.debug("No task to query celery, skipping sync")
return
+ self.update_all_task_states()
- self.log.debug("Inquiring about %s celery task(s) using %s processes",
- len(self.tasks), num_processes)
-
- # Recreate the process pool each sync in case processes in the pool die
- self._sync_pool = Pool(processes=num_processes)
+ def update_all_task_states(self) -> None:
+ """Updates states of the tasks."""
- # Use chunks instead of a work queue to reduce context switching since
tasks are
- # roughly uniform in size
- chunksize = self._num_tasks_per_fetch_process()
+ self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
+ states_by_celery_task_id =
self.bulk_state_fetcher.get_many(self.tasks.values())
- self.log.debug("Waiting for inquiries to complete...")
- task_keys_to_states = self._sync_pool.map(
- fetch_celery_task_state,
- self.tasks.items(),
- chunksize=chunksize)
- self._sync_pool.close()
- self._sync_pool.join()
self.log.debug("Inquiries completed.")
-
- self.update_task_states(task_keys_to_states)
-
- def update_task_states(self,
- task_keys_to_states:
List[Union[TaskInstanceStateType,
-
ExceptionWithTraceback]]) -> None:
- """Updates states of the tasks."""
- for key_and_state in task_keys_to_states:
- if isinstance(key_and_state, ExceptionWithTraceback):
- self.log.error( # pylint: disable=logging-not-lazy
- CELERY_FETCH_ERR_MSG_HEADER + ", ignoring it:%s\n%s\n",
- repr(key_and_state.exception), key_and_state.traceback
- )
- continue
- key, state = key_and_state
- self.update_task_state(key, state)
+ for key, async_result in list(self.tasks.items()):
+ state_by_task_id =
states_by_celery_task_id.get(async_result.task_id)
+ if state_by_task_id:
+ self.update_task_state(key, state_by_task_id)
def update_task_state(self, key: TaskInstanceKeyType, state: str) -> None:
"""Updates state of a single task."""
@@ -319,3 +262,109 @@ class CeleryExecutor(BaseExecutor):
def terminate(self):
pass
+
+
+def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str,
Union[str, ExceptionWithTraceback]]:
+ """
+ Fetch and return the state of the given celery task. The scope of this
function is
+ global so that it can be called by subprocesses in the pool.
+
+    :param async_result: the async Celery result object used to fetch the
task's state
+    :type async_result: celery.result.AsyncResult
+    :return: a tuple of the Celery task ID and the Celery state of the task,
+        or an ExceptionWithTraceback if fetching the state failed
+    :rtype: tuple[str, str | ExceptionWithTraceback]
+ """
+
+ try:
+ with timeout(seconds=OPERATION_TIMEOUT):
+ # Accessing state property of celery task will make actual network
request
+ # to get the current state of the task
+ return async_result.task_id, async_result.state
+ except Exception as e: # pylint: disable=broad-except
+ exception_traceback = f"Celery Task ID:
{async_result}\n{traceback.format_exc()}"
+ return async_result.task_id, ExceptionWithTraceback(e,
exception_traceback)
+
+
+def _tasks_list_to_task_ids(async_tasks) -> Set[str]:
+ return {a.task_id for a in async_tasks}
+
+
+class BulkStateFetcher(LoggingMixin):
+ """
+ Gets status for many Celery tasks using the best method available
+
+ If BaseKeyValueStoreBackend is used as result backend, the mget method is
used.
If DatabaseBackend is used as result backend, the SELECT ... WHERE task_id
IN (...) query is used
+ Otherwise, multiprocessing.Pool will be used. Each task status will be
downloaded individually.
+ """
+ def __init__(self, sync_parralelism=None):
+ super().__init__()
+ self._sync_parallelism = sync_parralelism
+
+ def get_many(self, async_results) -> Mapping[str, str]:
+ """
+ Gets status for many Celery tasks using the best method available.
+ """
+ if isinstance(app.backend, BaseKeyValueStoreBackend):
+ result = self._get_many_from_kv_backend(async_results)
+ return result
+ if isinstance(app.backend, DatabaseBackend):
+ result = self._get_many_from_db_backend(async_results)
+ return result
+ result = self._get_many_using_multiprocessing(async_results)
+ self.log.debug("Fetched %d states for %d task", len(result),
len(async_results))
+ return result
+
+ def _get_many_from_kv_backend(self, async_tasks) -> Mapping[str, str]:
+ task_ids = _tasks_list_to_task_ids(async_tasks)
+ keys = [app.backend.get_key_for_task(k) for k in task_ids]
+ values = app.backend.mget(keys)
+ task_results = [app.backend.decode_result(v) for v in values if v]
+ task_results_by_task_id = {task_result["task_id"]: task_result for
task_result in task_results}
+
+ return self._preapre_state_by_task_dict(task_ids,
task_results_by_task_id)
+
+ def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, str]:
+ task_ids = _tasks_list_to_task_ids(async_tasks)
+ session = app.backend.ResultSession()
+ with session_cleanup(session):
+ tasks =
session.query(TaskDb).filter(TaskDb.task_id.in_(task_ids)).all()
+
+ task_results = [app.backend.meta_from_decoded(task.to_dict()) for task
in tasks]
+ task_results_by_task_id = {task_result["task_id"]: task_result for
task_result in task_results}
+ return self._preapre_state_by_task_dict(task_ids,
task_results_by_task_id)
+
+ @staticmethod
+ def _preapre_state_by_task_dict(task_ids, task_results_by_task_id) ->
Mapping[str, str]:
+ states: MutableMapping[str, str] = {}
+ for task_id in task_ids:
+ task_result = task_results_by_task_id.get(task_id)
+ if task_result:
+ state = task_result["status"]
+ else:
+ state = celery_states.PENDING
+ states[task_id] = state
+ return states
+
+ def _get_many_using_multiprocessing(self, async_results) -> Mapping[str,
str]:
+ num_process = min(len(async_results), self._sync_parallelism)
+
+ with Pool(processes=num_process) as sync_pool:
+ chunksize = max(1, math.floor(math.ceil(1.0 * len(async_results) /
self._sync_parallelism)))
+
+ task_id_to_states_or_exception = sync_pool.map(
+ fetch_celery_task_state,
+ async_results,
+ chunksize=chunksize)
+
+ states_by_task_id: MutableMapping[str, str] = {}
+ for task_id, state_or_exception in task_id_to_states_or_exception:
+ if isinstance(state_or_exception, ExceptionWithTraceback):
+ self.log.error( # pylint: disable=logging-not-lazy
+ CELERY_FETCH_ERR_MSG_HEADER + ":%s\n%s\n",
+ state_or_exception.exception,
state_or_exception.traceback
+ )
+ else:
+ states_by_task_id[task_id] = state_or_exception
+ return states_by_task_id
diff --git a/tests/executors/test_celery_executor.py
b/tests/executors/test_celery_executor.py
index f94ff2c..2890544 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -17,6 +17,7 @@
# under the License.
import contextlib
import datetime
+import json
import os
import sys
import unittest
@@ -28,12 +29,15 @@ from unittest import mock
import celery.contrib.testing.tasks # noqa: F401 pylint: disable=unused-import
import pytest
from celery import Celery, states as celery_states
+from celery.backends.base import BaseBackend, BaseKeyValueStoreBackend
+from celery.backends.database import DatabaseBackend
from celery.contrib.testing.worker import start_worker
from kombu.asynchronous import set_event_loop
from parameterized import parameterized
from airflow.configuration import conf
from airflow.executors import celery_executor
+from airflow.executors.celery_executor import BulkStateFetcher
from airflow.models import TaskInstance
from airflow.models.dag import DAG
from airflow.models.taskinstance import SimpleTaskInstance
@@ -50,33 +54,43 @@ def _prepare_test_bodies():
return [(conf.get('celery', 'BROKER_URL'))]
-class TestCeleryExecutor(unittest.TestCase):
+class FakeCeleryResult:
+ @property
+ def state(self):
+ raise Exception()
+
+ def task_id(self):
+ return "task_id"
+
+
[email protected]
+def _prepare_app(broker_url=None, execute=None):
+ broker_url = broker_url or conf.get('celery', 'BROKER_URL')
+ execute = execute or celery_executor.execute_command.__wrapped__
+
+ test_config = dict(celery_executor.celery_configuration)
+ test_config.update({'broker_url': broker_url})
+ test_app = Celery(broker_url, config_source=test_config)
+ test_execute = test_app.task(execute)
+ patch_app = mock.patch('airflow.executors.celery_executor.app', test_app)
+ patch_execute =
mock.patch('airflow.executors.celery_executor.execute_command', test_execute)
- @contextlib.contextmanager
- def _prepare_app(self, broker_url=None, execute=None):
- broker_url = broker_url or conf.get('celery', 'BROKER_URL')
- execute = execute or celery_executor.execute_command.__wrapped__
-
- test_config = dict(celery_executor.celery_configuration)
- test_config.update({'broker_url': broker_url})
- test_app = Celery(broker_url, config_source=test_config)
- test_execute = test_app.task(execute)
- patch_app = mock.patch('airflow.executors.celery_executor.app',
test_app)
- patch_execute =
mock.patch('airflow.executors.celery_executor.execute_command', test_execute)
-
- with patch_app, patch_execute:
- try:
- yield test_app
- finally:
- # Clear event loop to tear down each celery instance
- set_event_loop(None)
+ with patch_app, patch_execute:
+ try:
+ yield test_app
+ finally:
+ # Clear event loop to tear down each celery instance
+ set_event_loop(None)
+
+
+class TestCeleryExecutor(unittest.TestCase):
@parameterized.expand(_prepare_test_bodies())
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@pytest.mark.backend("mysql", "postgres")
def test_celery_integration(self, broker_url):
- with self._prepare_app(broker_url) as app:
+ with _prepare_app(broker_url) as app:
executor = celery_executor.CeleryExecutor()
executor.start()
@@ -98,14 +112,11 @@ class TestCeleryExecutor(unittest.TestCase):
chunksize =
executor._num_tasks_per_send_process(len(task_tuples_to_send))
num_processes = min(len(task_tuples_to_send),
executor._sync_parallelism)
- send_pool = Pool(processes=num_processes)
- key_and_async_results = send_pool.map(
- celery_executor.send_task_to_executor,
- task_tuples_to_send,
- chunksize=chunksize)
-
- send_pool.close()
- send_pool.join()
+ with Pool(processes=num_processes) as send_pool:
+ key_and_async_results = send_pool.map(
+ celery_executor.send_task_to_executor,
+ task_tuples_to_send,
+ chunksize=chunksize)
for task_instance_key, _, result in key_and_async_results:
# Only pops when enqueued successfully, otherwise keep it
@@ -136,7 +147,7 @@ class TestCeleryExecutor(unittest.TestCase):
def fake_execute_command():
pass
- with self._prepare_app(execute=fake_execute_command):
+ with _prepare_app(execute=fake_execute_command):
# fake_execute_command takes no arguments while execute_command
takes 1,
# which will cause TypeError when calling task.apply_async()
executor = celery_executor.CeleryExecutor()
@@ -155,25 +166,16 @@ class TestCeleryExecutor(unittest.TestCase):
self.assertEqual(executor.queued_tasks[key], value_tuple)
def test_exception_propagation(self):
- with self._prepare_app() as app:
- @app.task
- def fake_celery_task():
- return {}
- mock_log = mock.MagicMock()
+ with _prepare_app(), self.assertLogs(celery_executor.log) as cm:
executor = celery_executor.CeleryExecutor()
- executor._log = mock_log
-
- executor.tasks = {'key': fake_celery_task()}
- executor.sync()
+ executor.tasks = {
+ 'key': FakeCeleryResult()
+ }
+
executor.bulk_state_fetcher._get_many_using_multiprocessing(executor.tasks.values())
- assert mock_log.error.call_count == 1
- args, kwargs = mock_log.error.call_args_list[0]
- # Result of queuing is not a celery task but a dict,
- # and it should raise AttributeError and then get propagated
- # to the error log.
- self.assertIn(celery_executor.CELERY_FETCH_ERR_MSG_HEADER, args[0])
- self.assertIn('AttributeError', args[1])
+ self.assertTrue(any(celery_executor.CELERY_FETCH_ERR_MSG_HEADER in
line for line in cm.output))
+ self.assertTrue(any("Exception" in line for line in cm.output))
@mock.patch('airflow.executors.celery_executor.CeleryExecutor.sync')
@mock.patch('airflow.executors.celery_executor.CeleryExecutor.trigger_tasks')
@@ -189,3 +191,87 @@ class TestCeleryExecutor(unittest.TestCase):
def test_operation_timeout_config():
assert celery_executor.OPERATION_TIMEOUT == 2
+
+
+class ClassWithCustomAttributes:
+ """Class for testing purpose: allows to create objects with custom
attributes in one single statement."""
+
+ def __init__(self, **kwargs):
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ def __str__(self):
+ return "{}({})".format(ClassWithCustomAttributes.__name__,
str(self.__dict__))
+
+ def __repr__(self):
+ return self.__str__()
+
+ def __eq__(self, other):
+ return self.__dict__ == other.__dict__
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class TestBulkStateFetcher(unittest.TestCase):
+
+ @mock.patch("celery.backends.base.BaseKeyValueStoreBackend.mget",
return_value=[
+ json.dumps({"status": "SUCCESS", "task_id": "123"})
+ ])
+ @pytest.mark.integration("redis")
+ @pytest.mark.integration("rabbitmq")
+ @pytest.mark.backend("mysql", "postgres")
+ def test_should_support_kv_backend(self, mock_mget):
+ with _prepare_app():
+ mock_backend = BaseKeyValueStoreBackend(app=celery_executor.app)
+ with mock.patch.object(celery_executor.app, 'backend',
mock_backend):
+ fetcher = BulkStateFetcher()
+ result = fetcher.get_many([
+ mock.MagicMock(task_id="123"),
+ mock.MagicMock(task_id="456"),
+ ])
+
+ # Assert called - ignore order
+ mget_args, _ = mock_mget.call_args
+ self.assertEqual(set(mget_args[0]), {b'celery-task-meta-456',
b'celery-task-meta-123'})
+ mock_mget.assert_called_once_with(mock.ANY)
+
+ self.assertEqual(result, {'123': 'SUCCESS', '456': "PENDING"})
+
+ @mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
+ @pytest.mark.integration("redis")
+ @pytest.mark.integration("rabbitmq")
+ @pytest.mark.backend("mysql", "postgres")
+ def test_should_support_db_backend(self, mock_session):
+ with _prepare_app():
+ mock_backend = DatabaseBackend(app=celery_executor.app,
url="sqlite3://")
+
+ with mock.patch.object(celery_executor.app, 'backend',
mock_backend):
+ mock_session = mock_backend.ResultSession.return_value #
pylint: disable=no-member
+
mock_session.query.return_value.filter.return_value.all.return_value = [
+ mock.MagicMock(**{"to_dict.return_value": {"status":
"SUCCESS", "task_id": "123"}})
+ ]
+
+ fetcher = BulkStateFetcher()
+ result = fetcher.get_many([
+ mock.MagicMock(task_id="123"),
+ mock.MagicMock(task_id="456"),
+ ])
+
+ self.assertEqual(result, {'123': 'SUCCESS', '456': "PENDING"})
+
+ @pytest.mark.integration("redis")
+ @pytest.mark.integration("rabbitmq")
+ @pytest.mark.backend("mysql", "postgres")
+ def test_should_support_base_backend(self):
+ with _prepare_app():
+ mock_backend = mock.MagicMock(autospec=BaseBackend)
+
+ with mock.patch.object(celery_executor.app, 'backend',
mock_backend):
+ fetcher = BulkStateFetcher(1)
+ result = fetcher.get_many([
+ ClassWithCustomAttributes(task_id="123", state='SUCCESS'),
+ ClassWithCustomAttributes(task_id="456", state="PENDING"),
+ ])
+
+ self.assertEqual(result, {'123': 'SUCCESS', '456': "PENDING"})