This is an automated email from the ASF dual-hosted git repository.

rahulvats pushed a commit to branch py-client-sync
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 2058516acce1e33b3f15790b51023cc216d4e4e0
Author: Elad Kalif <[email protected]>
AuthorDate: Tue Mar 24 18:37:47 2026 +0200

    Make test_celery_integration runnable (#64153)
---
 .../integration/celery/test_celery_executor.py     | 83 +++++++++++-----------
 1 file changed, 41 insertions(+), 42 deletions(-)

diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py
index 5e05c74a7d6..3641abb6b1f 100644
--- a/providers/celery/tests/integration/celery/test_celery_executor.py
+++ b/providers/celery/tests/integration/celery/test_celery_executor.py
@@ -37,10 +37,10 @@ from celery.backends.database import DatabaseBackend
 from celery.contrib.testing.worker import start_worker
 from kombu.asynchronous import set_event_loop
 from kubernetes.client import models as k8s
-from uuid6 import uuid7
 
 from airflow._shared.timezones import timezone
 from airflow.executors import workloads
+from airflow.executors.workloads.base import BundleInfo
 from airflow.executors.workloads.task import TaskInstanceDTO
 from airflow.models.dag import DAG
 from airflow.models.taskinstance import TaskInstance
@@ -128,17 +128,6 @@ class TestCeleryExecutor:
         db.clear_db_runs()
         db.clear_db_jobs()
 
-
-def setup_dagrun_with_success_and_fail_tasks(dag_maker):
-    date = timezone.utcnow()
-    start_date = date - timedelta(days=2)
-
-    with dag_maker("test_celery_integration"):
-        BaseOperator(task_id="success", start_date=start_date)
-        BaseOperator(task_id="fail", start_date=start_date)
-
-    return dag_maker.create_dagrun(logical_date=date)
-
     @pytest.mark.flaky(reruns=5, reruns_delay=3)
     @pytest.mark.parametrize("broker_url", _prepare_test_bodies())
     @pytest.mark.parametrize(
@@ -171,8 +160,8 @@ def setup_dagrun_with_success_and_fail_tasks(dag_maker):
             ),
         ],
     )
-    def test_celery_integration(self, broker_url, executor_config):
-        from airflow.providers.celery.executors import celery_executor, celery_executor_utils
+    def test_celery_integration(self, broker_url, executor_config, dag_maker):
+        from airflow.providers.celery.executors import celery_executor
 
         if AIRFLOW_V_3_0_PLUS:
             # Airflow 3: execute_workload receives JSON string
@@ -193,32 +182,42 @@ def setup_dagrun_with_success_and_fail_tasks(dag_maker):
 
         with _prepare_app(broker_url, execute=fake_execute) as app:
             executor = celery_executor.CeleryExecutor()
+            # Force single-process sending so mock patches survive (ProcessPoolExecutor
+            # would fork new processes where the patches are not active).
+            executor._sync_parallelism = 1
             assert executor.tasks == {}
             executor.start()
 
             with start_worker(app=app, logfile=sys.stdout, loglevel="info"):
-                ti_success = TaskInstanceDTO.model_construct(
-                    id=uuid7(),
-                    task_id="success",
-                    dag_id="id",
-                    run_id="abc",
-                    try_number=0,
-                    priority_weight=1,
-                    queue=celery_executor_utils.get_celery_configuration()["task_default_queue"],
-                    executor_config=executor_config,
+                dagrun_date = timezone.utcnow()
+                dagrun_start = dagrun_date - timedelta(days=2)
+                with dag_maker("test_celery_integration"):
+                    BaseOperator(task_id="success", start_date=dagrun_start)
+                    BaseOperator(task_id="fail", start_date=dagrun_start)
+                dagrun = dag_maker.create_dagrun(logical_date=dagrun_date)
+                ti_fail, ti_success = sorted(dagrun.task_instances, key=lambda ti: ti.task_id)
+                # Derive keys from the real task instances so they match what the executor tracks
+                key_fail = TaskInstanceKey(
+                    ti_fail.dag_id, ti_fail.task_id, ti_fail.run_id, ti_fail.try_number, ti_fail.map_index
                 )
-                keys = [
-                    TaskInstanceKey("id", "success", "abc", 0, -1),
-                    TaskInstanceKey("id", "fail", "abc", 0, -1),
-                ]
-                dagrun = setup_dagrun_with_success_and_fail_tasks(dag_maker)
-                ti_success, ti_fail = dagrun.task_instances
-                for w in (
-                    workloads.ExecuteTask.make(
-                        ti=ti_success,
-                    ),
-                    workloads.ExecuteTask.make(ti=ti_fail),
-                ):
+                key_success = TaskInstanceKey(
+                    ti_success.dag_id,
+                    ti_success.task_id,
+                    ti_success.run_id,
+                    ti_success.try_number,
+                    ti_success.map_index,
+                )
+                keys = [key_fail, key_success]
+                for ti in (ti_success, ti_fail):
+                    ti_dto = TaskInstanceDTO.model_validate(ti, from_attributes=True)
+                    ti_dto.executor_config = executor_config
+                    w = workloads.ExecuteTask(
+                        ti=ti_dto,
+                        dag_rel_path="test.py",
+                        token="",
+                        bundle_info=BundleInfo(name="test"),
+                        log_path="test.log",
+                    )
                     executor.queue_workload(w, session=None)
 
                 executor.trigger_tasks(open_slots=10)
@@ -231,17 +230,17 @@ def setup_dagrun_with_success_and_fail_tasks(dag_maker):
                         num_tasks,
                     )
                     sleep(0.4)
-                assert list(executor.tasks.keys()) == keys
-                assert executor.event_buffer[keys[0]][0] == State.QUEUED
-                assert executor.event_buffer[keys[1]][0] == State.QUEUED
+                assert sorted(executor.tasks.keys()) == sorted(keys)
+                assert executor.event_buffer[key_success][0] == State.QUEUED
+                assert executor.event_buffer[key_fail][0] == State.QUEUED
 
                 executor.end(synchronous=True)
 
-        assert executor.event_buffer[keys[0]][0] == State.SUCCESS
-        assert executor.event_buffer[keys[1]][0] == State.FAILED
+        assert executor.event_buffer[key_success][0] == State.SUCCESS
+        assert executor.event_buffer[key_fail][0] == State.FAILED
 
-        assert keys[0] not in executor.tasks
-        assert keys[1] not in executor.tasks
+        assert key_success not in executor.tasks
+        assert key_fail not in executor.tasks
 
         assert executor.queued_tasks == {}
 

Reply via email to