ashb commented on code in PR #58365:
URL: https://github.com/apache/airflow/pull/58365#discussion_r2560055767


##########
airflow-core/tests/unit/executors/test_local_executor.py:
##########
@@ -158,16 +173,73 @@ def test_clean_stop_on_signal(self):
         executor = LocalExecutor(parallelism=2)
         executor.start()
 
-        # We want to ensure we start a worker process, as we now only create 
them on demand
-        executor._spawn_worker()
-
         try:
             os.kill(os.getpid(), signal.SIGINT)
         except KeyboardInterrupt:
             pass
         finally:
             executor.end()
 
+    @skip_spawn_mp_start
+    def test_worker_process_revive(self):
+        executor = LocalExecutor(parallelism=2)
+        executor.start()
+
+        # Mock the process to make it appear dead.
+        # However, the processes that lost their references must be included 
in end() before termination.
+        # Otherwise, the test will not finish and a timeout will occur.
+        dead_process = {}
+
+        for killed_pid, killed_proc in executor.workers.items():
+            proc = mock.MagicMock()
+            proc.is_alive.return_value = False
+
+            dead_process[killed_pid] = killed_proc
+            executor.workers[killed_pid] = proc
+
+        success_tis = [
+            workloads.TaskInstance(
+                id=uuid7(),
+                dag_version_id=uuid7(),
+                task_id=f"success_{i}",
+                dag_id="mydag",
+                run_id="run1",
+                try_number=1,
+                state="queued",
+                pool_slots=1,
+                queue="default",
+                priority_weight=1,
+                map_index=-1,
+                start_date=timezone.utcnow(),
+            )
+            for i in range(self.TEST_SUCCESS_COMMANDS)
+        ]
+
+        for ti in success_tis:
+            executor.queue_workload(
+                workloads.ExecuteTask(
+                    token="",
+                    ti=ti,
+                    dag_rel_path="some/path",
+                    log_path=None,
+                    bundle_info=dict(name="hi", version="hi"),
+                ),
+                session=mock.MagicMock(spec=Session),
+            )
+
+        with spy_on(executor._spawn_worker) as spawn_worker:
+            executor._process_workloads(list(executor.queued_tasks.values()))
+
+            if executor.is_mp_using_fork:
+                assert len(spawn_worker.calls) == 2
+            else:
+                assert len(spawn_worker.calls) == 1

Review Comment:
   We have run into problems with timing-based approaches on CI before, so I think this is safer. 
When this sort of assert was first added, it too was `== 2`, but I think it's 
generally safer to use this range check no matter the spawn method.
   
   ```suggestion
               assert 1 <= len(spawn_worker.calls) <= 2
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to