This is an automated email from the ASF dual-hosted git repository.

jasonliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4ee07a9a1cf Fix: Use get instead of hasattr for task_result in BulkStateFetcher (#52839)
4ee07a9a1cf is described below

commit 4ee07a9a1cf0b57528d5c411d13fdc8d155aebdb
Author: Wei-Yu Chen <[email protected]>
AuthorDate: Sat Sep 20 00:32:29 2025 -0400

    Fix: Use get instead of hasattr for task_result in BulkStateFetcher (#52839)
    
    * Fix: Use get instead of hasattr for task_result in BulkStateFetcher
    
    * add type annotation for task_results_by_task_id
    
    * add type annotation for params in methods of BulkStateFetcher
    
    * retain type annotation only in param level
    
    * add mock value for sync_parallelism in test
---
 .../celery/executors/celery_executor_utils.py      | 26 +++++++++++++---------
 .../integration/celery/test_celery_executor.py     |  6 ++---
 2 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
index 76202dd139f..8ccbc9b56dd 100644
--- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
+++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py
@@ -30,7 +30,7 @@ import subprocess
 import sys
 import traceback
 import warnings
-from collections.abc import Mapping, MutableMapping, Sequence
+from collections.abc import Collection, Mapping, MutableMapping, Sequence
 from concurrent.futures import ProcessPoolExecutor
 from typing import TYPE_CHECKING, Any
 
@@ -323,14 +323,14 @@ class BulkStateFetcher(LoggingMixin):
     Otherwise, multiprocessing.Pool will be used. Each task status will be downloaded individually.
     """
 
-    def __init__(self, sync_parallelism=None):
+    def __init__(self, sync_parallelism: int):
         super().__init__()
         self._sync_parallelism = sync_parallelism
 
-    def _tasks_list_to_task_ids(self, async_tasks) -> set[str]:
+    def _tasks_list_to_task_ids(self, async_tasks: Collection[AsyncResult]) -> set[str]:
         return {a.task_id for a in async_tasks}
 
-    def get_many(self, async_results) -> Mapping[str, EventBufferValueType]:
+    def get_many(self, async_results: Collection[AsyncResult]) -> Mapping[str, EventBufferValueType]:
         """Get status for many Celery tasks using the best method available."""
         if isinstance(app.backend, BaseKeyValueStoreBackend):
             result = self._get_many_from_kv_backend(async_results)
@@ -341,7 +341,9 @@ class BulkStateFetcher(LoggingMixin):
         self.log.debug("Fetched %d state(s) for %d task(s)", len(result), len(async_results))
         return result
 
-    def _get_many_from_kv_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
+    def _get_many_from_kv_backend(
+        self, async_tasks: Collection[AsyncResult]
+    ) -> Mapping[str, EventBufferValueType]:
         task_ids = self._tasks_list_to_task_ids(async_tasks)
         keys = [app.backend.get_key_for_task(k) for k in task_ids]
         values = app.backend.mget(keys)
@@ -351,13 +353,15 @@ class BulkStateFetcher(LoggingMixin):
         return self._prepare_state_and_info_by_task_dict(task_ids, task_results_by_task_id)
 
     @retry
-    def _query_task_cls_from_db_backend(self, task_ids, **kwargs):
+    def _query_task_cls_from_db_backend(self, task_ids: set[str], **kwargs):
         session = app.backend.ResultSession()
         task_cls = getattr(app.backend, "task_cls", TaskDb)
         with session_cleanup(session):
             return session.scalars(select(task_cls).where(task_cls.task_id.in_(task_ids))).all()
 
-    def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, EventBufferValueType]:
+    def _get_many_from_db_backend(
+        self, async_tasks: Collection[AsyncResult]
+    ) -> Mapping[str, EventBufferValueType]:
         task_ids = self._tasks_list_to_task_ids(async_tasks)
         tasks = self._query_task_cls_from_db_backend(task_ids)
         task_results = [app.backend.meta_from_decoded(task.to_dict()) for task in tasks]
@@ -367,21 +371,23 @@ class BulkStateFetcher(LoggingMixin):
 
     @staticmethod
     def _prepare_state_and_info_by_task_dict(
-        task_ids, task_results_by_task_id
+        task_ids: set[str], task_results_by_task_id: dict[str, dict[str, Any]]
     ) -> Mapping[str, EventBufferValueType]:
         state_info: MutableMapping[str, EventBufferValueType] = {}
         for task_id in task_ids:
             task_result = task_results_by_task_id.get(task_id)
             if task_result:
                 state = task_result["status"]
-                info = None if not hasattr(task_result, "info") else task_result["info"]
+                info = task_result.get("info")
             else:
                 state = celery_states.PENDING
                 info = None
             state_info[task_id] = state, info
         return state_info
 
-    def _get_many_using_multiprocessing(self, async_results) -> Mapping[str, EventBufferValueType]:
+    def _get_many_using_multiprocessing(
+        self, async_results: Collection[AsyncResult]
+    ) -> Mapping[str, EventBufferValueType]:
         num_process = min(len(async_results), self._sync_parallelism)
 
         with ProcessPoolExecutor(max_workers=num_process) as sync_pool:
diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py
index 874fbce3cbc..d16a3435810 100644
--- a/providers/celery/tests/integration/celery/test_celery_executor.py
+++ b/providers/celery/tests/integration/celery/test_celery_executor.py
@@ -335,7 +335,7 @@ class TestBulkStateFetcher:
                 "airflow.providers.celery.executors.celery_executor_utils.Celery.backend", mock_backend
             ):
                 caplog.clear()
-                fetcher = celery_executor_utils.BulkStateFetcher()
+                fetcher = celery_executor_utils.BulkStateFetcher(1)
                 result = fetcher.get_many(
                     [
                         mock.MagicMock(task_id="123"),
@@ -367,7 +367,7 @@ class TestBulkStateFetcher:
                     mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}})
                 ]
 
-                fetcher = celery_executor_utils.BulkStateFetcher()
+                fetcher = celery_executor_utils.BulkStateFetcher(1)
                 result = fetcher.get_many(
                     [
                         mock.MagicMock(task_id="123"),
@@ -401,7 +401,7 @@ class TestBulkStateFetcher:
                     mock_retry_db_result.return_value,
                 ]
 
-                fetcher = celery_executor_utils.BulkStateFetcher()
+                fetcher = celery_executor_utils.BulkStateFetcher(1)
                 result = fetcher.get_many(
                     [
                         mock.MagicMock(task_id="123"),

Reply via email to