This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch v2-0-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 22b2a800ba81e2a90ef40b7a92eb80d4eb67acb2 Author: Ryan Hatter <[email protected]> AuthorDate: Tue Apr 6 05:21:38 2021 -0400 Fix celery executor bug trying to call len on map (#14883) Co-authored-by: RNHTTR <[email protected]> (cherry picked from commit 4ee442970873ba59ee1d1de3ac78ef8e33666e0f) --- airflow/executors/celery_executor.py | 22 ++++++++++----------- tests/executors/test_celery_executor.py | 35 +++++++++++++++++++++++---------- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index a670294..2d0e915 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -476,7 +476,7 @@ class CeleryExecutor(BaseExecutor): return tis states_by_celery_task_id = self.bulk_state_fetcher.get_many( - map(operator.itemgetter(0), celery_tasks.values()) + list(map(operator.itemgetter(0), celery_tasks.values())) ) adopted = [] @@ -526,10 +526,6 @@ def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, return async_result.task_id, ExceptionWithTraceback(e, exception_traceback), None -def _tasks_list_to_task_ids(async_tasks) -> Set[str]: - return {a.task_id for a in async_tasks} - - class BulkStateFetcher(LoggingMixin): """ Gets status for many Celery tasks using the best method available @@ -543,20 +539,22 @@ class BulkStateFetcher(LoggingMixin): super().__init__() self._sync_parallelism = sync_parralelism + def _tasks_list_to_task_ids(self, async_tasks) -> Set[str]: + return {a.task_id for a in async_tasks} + def get_many(self, async_results) -> Mapping[str, EventBufferValueType]: """Gets status for many Celery tasks using the best method available.""" if isinstance(app.backend, BaseKeyValueStoreBackend): result = self._get_many_from_kv_backend(async_results) - return result - if isinstance(app.backend, DatabaseBackend): + elif isinstance(app.backend, DatabaseBackend): result = self._get_many_from_db_backend(async_results) - return result - result = self._get_many_using_multiprocessing(async_results) - self.log.debug("Fetched %d states for %d task", len(result), len(async_results)) + else: + result = self._get_many_using_multiprocessing(async_results) + self.log.debug("Fetched %d state(s) for %d task(s)", len(result), len(async_results)) return result def _get_many_from_kv_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]: - task_ids = _tasks_list_to_task_ids(async_tasks) + task_ids = self._tasks_list_to_task_ids(async_tasks) keys = [app.backend.get_key_for_task(k) for k in task_ids] values = app.backend.mget(keys) task_results = [app.backend.decode_result(v) for v in values if v] @@ -565,7 +563,7 @@ class BulkStateFetcher(LoggingMixin): return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id) def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]: - task_ids = _tasks_list_to_task_ids(async_tasks) + task_ids = self._tasks_list_to_task_ids(async_tasks) session = app.backend.ResultSession() task_cls = getattr(app.backend, "task_cls", TaskDb) with session_cleanup(session): diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index 944fa49..4f93007 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -414,7 +414,9 @@ class TestBulkStateFetcher(unittest.TestCase): def test_should_support_kv_backend(self, mock_mget): with _prepare_app(): mock_backend = BaseKeyValueStoreBackend(app=celery_executor.app) - with mock.patch.object(celery_executor.app, 'backend', mock_backend): + with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs( + "airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG" + ) as cm: fetcher = BulkStateFetcher() result = fetcher.get_many( [ @@ -429,6 +431,9 @@ class TestBulkStateFetcher(unittest.TestCase): mock_mget.assert_called_once_with(mock.ANY) assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)} + assert [ + 'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)' + ] == cm.output @mock.patch("celery.backends.database.DatabaseBackend.ResultSession") @pytest.mark.integration("redis") @@ -438,21 +443,26 @@ class TestBulkStateFetcher(unittest.TestCase): with _prepare_app(): mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://") - with mock.patch.object(celery_executor.app, 'backend', mock_backend): + with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs( + "airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG" + ) as cm: mock_session = mock_backend.ResultSession.return_value # pylint: disable=no-member mock_session.query.return_value.filter.return_value.all.return_value = [ mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}}) ] - fetcher = BulkStateFetcher() - result = fetcher.get_many( - [ - mock.MagicMock(task_id="123"), - mock.MagicMock(task_id="456"), - ] - ) + fetcher = BulkStateFetcher() + result = fetcher.get_many( + [ + mock.MagicMock(task_id="123"), + mock.MagicMock(task_id="456"), + ] + ) assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)} + assert [ + 'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)' + ] == cm.output @pytest.mark.integration("redis") @pytest.mark.integration("rabbitmq") @@ -461,7 +471,9 @@ class TestBulkStateFetcher(unittest.TestCase): with _prepare_app(): mock_backend = mock.MagicMock(autospec=BaseBackend) - with mock.patch.object(celery_executor.app, 'backend', mock_backend): + with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs( + "airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG" + ) as cm: fetcher = BulkStateFetcher(1) result = fetcher.get_many( [ @@ -471,3 +483,6 @@ class TestBulkStateFetcher(unittest.TestCase): ) assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)} + assert [ + 'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)' + ] == cm.output
