This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c3df47efc2 Add retry functionality for handling process termination
caused by database network issues (#31998)
c3df47efc2 is described below
commit c3df47efc2911706897bf577af8a475178de4b1b
Author: yiqijiu <[email protected]>
AuthorDate: Tue Jun 27 01:01:26 2023 +0800
Add retry functionality for handling process termination caused by database
network issues (#31998)
---
airflow/executors/celery_executor_utils.py | 12 +++++---
.../integration/executors/test_celery_executor.py | 33 ++++++++++++++++++++++
2 files changed, 41 insertions(+), 4 deletions(-)
diff --git a/airflow/executors/celery_executor_utils.py b/airflow/executors/celery_executor_utils.py
index 2c8af4cf91..6330b62c6c 100644
--- a/airflow/executors/celery_executor_utils.py
+++ b/airflow/executors/celery_executor_utils.py
@@ -31,7 +31,7 @@ from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Tuple
from celery import Celery, Task, states as celery_states
from celery.backends.base import BaseKeyValueStoreBackend
-from celery.backends.database import DatabaseBackend, Task as TaskDb, session_cleanup
+from celery.backends.database import DatabaseBackend, Task as TaskDb, retry, session_cleanup
from celery.result import AsyncResult
from celery.signals import import_modules as celery_import_modules
from setproctitle import setproctitle
@@ -250,15 +250,19 @@ class BulkStateFetcher(LoggingMixin):
         return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)
-    def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
- task_ids = self._tasks_list_to_task_ids(async_tasks)
+ @retry
+ def _query_task_cls_from_db_backend(self, task_ids, **kwargs):
session = app.backend.ResultSession()
task_cls = getattr(app.backend, "task_cls", TaskDb)
with session_cleanup(session):
-            tasks = session.query(task_cls).filter(task_cls.task_id.in_(task_ids)).all()
+            return session.query(task_cls).filter(task_cls.task_id.in_(task_ids)).all()
+    def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
+ task_ids = self._tasks_list_to_task_ids(async_tasks)
+ tasks = self._query_task_cls_from_db_backend(task_ids)
         task_results = [app.backend.meta_from_decoded(task.to_dict()) for task in tasks]
         task_results_by_task_id = {task_result["task_id"]: task_result for task_result in task_results}
+
         return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)
@staticmethod
diff --git a/tests/integration/executors/test_celery_executor.py b/tests/integration/executors/test_celery_executor.py
index 4be7633c3e..a89034d2af 100644
--- a/tests/integration/executors/test_celery_executor.py
+++ b/tests/integration/executors/test_celery_executor.py
@@ -310,6 +310,39 @@ class TestBulkStateFetcher:
assert result == {"123": ("SUCCESS", None), "456": ("PENDING", None)}
assert caplog.messages == ["Fetched 2 state(s) for 2 task(s)"]
+ @mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
+ def test_should_retry_db_backend(self, mock_session, caplog):
+ caplog.set_level(logging.DEBUG, logger=self.bulk_state_fetcher_logger)
+ from sqlalchemy.exc import DatabaseError
+
+ with _prepare_app():
+            mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://")
+            with mock.patch("airflow.executors.celery_executor_utils.Celery.backend", mock_backend):
+ caplog.clear()
+ mock_session = mock_backend.ResultSession.return_value
+                mock_retry_db_result = mock_session.query.return_value.filter.return_value.all
+ mock_retry_db_result.return_value = [
+                    mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}})
+ ]
+ mock_retry_db_result.side_effect = [
+                    DatabaseError("DatabaseError", "DatabaseError", "DatabaseError"),
+ mock_retry_db_result.return_value,
+ ]
+
+ fetcher = celery_executor_utils.BulkStateFetcher()
+ result = fetcher.get_many(
+ [
+ mock.MagicMock(task_id="123"),
+ mock.MagicMock(task_id="456"),
+ ]
+ )
+ assert mock_retry_db_result.call_count == 2
+ assert result == {"123": ("SUCCESS", None), "456": ("PENDING", None)}
+ assert caplog.messages == [
+                    "Failed operation _query_task_cls_from_db_backend. Retrying 2 more times.",
+ "Fetched 2 state(s) for 2 task(s)",
+ ]
+
def test_should_support_base_backend(self, caplog):
caplog.set_level(logging.DEBUG, logger=self.bulk_state_fetcher_logger)
with _prepare_app():