This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c3df47efc2 Add retry functionality for handling process termination caused by database network issues (#31998)
c3df47efc2 is described below

commit c3df47efc2911706897bf577af8a475178de4b1b
Author: yiqijiu <[email protected]>
AuthorDate: Tue Jun 27 01:01:26 2023 +0800

    Add retry functionality for handling process termination caused by database network issues (#31998)
---
 airflow/executors/celery_executor_utils.py         | 12 +++++---
 .../integration/executors/test_celery_executor.py  | 33 ++++++++++++++++++++++
 2 files changed, 41 insertions(+), 4 deletions(-)

diff --git a/airflow/executors/celery_executor_utils.py b/airflow/executors/celery_executor_utils.py
index 2c8af4cf91..6330b62c6c 100644
--- a/airflow/executors/celery_executor_utils.py
+++ b/airflow/executors/celery_executor_utils.py
@@ -31,7 +31,7 @@ from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Tuple
 
 from celery import Celery, Task, states as celery_states
 from celery.backends.base import BaseKeyValueStoreBackend
-from celery.backends.database import DatabaseBackend, Task as TaskDb, session_cleanup
+from celery.backends.database import DatabaseBackend, Task as TaskDb, retry, session_cleanup
 from celery.result import AsyncResult
 from celery.signals import import_modules as celery_import_modules
 from setproctitle import setproctitle
@@ -250,15 +250,19 @@ class BulkStateFetcher(LoggingMixin):
 
         return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)
 
-    def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
-        task_ids = self._tasks_list_to_task_ids(async_tasks)
+    @retry
+    def _query_task_cls_from_db_backend(self, task_ids, **kwargs):
         session = app.backend.ResultSession()
         task_cls = getattr(app.backend, "task_cls", TaskDb)
         with session_cleanup(session):
-            tasks = session.query(task_cls).filter(task_cls.task_id.in_(task_ids)).all()
+            return session.query(task_cls).filter(task_cls.task_id.in_(task_ids)).all()
 
+    def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
+        task_ids = self._tasks_list_to_task_ids(async_tasks)
+        tasks = self._query_task_cls_from_db_backend(task_ids)
         task_results = [app.backend.meta_from_decoded(task.to_dict()) for task in tasks]
         task_results_by_task_id = {task_result["task_id"]: task_result for task_result in task_results}
+
         return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)
 
     @staticmethod
diff --git a/tests/integration/executors/test_celery_executor.py b/tests/integration/executors/test_celery_executor.py
index 4be7633c3e..a89034d2af 100644
--- a/tests/integration/executors/test_celery_executor.py
+++ b/tests/integration/executors/test_celery_executor.py
@@ -310,6 +310,39 @@ class TestBulkStateFetcher:
         assert result == {"123": ("SUCCESS", None), "456": ("PENDING", None)}
         assert caplog.messages == ["Fetched 2 state(s) for 2 task(s)"]
 
+    @mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
+    def test_should_retry_db_backend(self, mock_session, caplog):
+        caplog.set_level(logging.DEBUG, logger=self.bulk_state_fetcher_logger)
+        from sqlalchemy.exc import DatabaseError
+
+        with _prepare_app():
+            mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://")
+            with mock.patch("airflow.executors.celery_executor_utils.Celery.backend", mock_backend):
+                caplog.clear()
+                mock_session = mock_backend.ResultSession.return_value
+                mock_retry_db_result = mock_session.query.return_value.filter.return_value.all
+                mock_retry_db_result.return_value = [
+                    mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}})
+                ]
+                mock_retry_db_result.side_effect = [
+                    DatabaseError("DatabaseError", "DatabaseError", "DatabaseError"),
+                    mock_retry_db_result.return_value,
+                ]
+
+                fetcher = celery_executor_utils.BulkStateFetcher()
+                result = fetcher.get_many(
+                    [
+                        mock.MagicMock(task_id="123"),
+                        mock.MagicMock(task_id="456"),
+                    ]
+                )
+        assert mock_retry_db_result.call_count == 2
+        assert result == {"123": ("SUCCESS", None), "456": ("PENDING", None)}
+        assert caplog.messages == [
+            "Failed operation _query_task_cls_from_db_backend.  Retrying 2 more times.",
+            "Fetched 2 state(s) for 2 task(s)",
+        ]
+
     def test_should_support_base_backend(self, caplog):
         caplog.set_level(logging.DEBUG, logger=self.bulk_state_fetcher_logger)
         with _prepare_app():

Reply via email to