ashb commented on a change in pull request #16301:
URL: https://github.com/apache/airflow/pull/16301#discussion_r649180426
##########
File path: tests/jobs/test_local_task_job.py
##########
@@ -571,6 +572,79 @@ def task_function(ti):
assert task_terminated_externally.value == 1
assert not process.is_alive()
+ @parameterized.expand(
+ [
+ (signal.SIGTERM,),
+ (signal.SIGKILL,),
+ ]
+ )
+ def test_process_kill_calls_works_with_retries(self, signal_type):
+ """
+ Test that ensures that tasks are set to up-for-retry when they receive
+ a SIGKILL or SIGTERM, and that failure_callback is not called on
+ getting a SIGTERM
+ """
+ # use shared memory value so we can properly track value change even if
+ # it's been updated across processes.
+ failure_callback_called = Value('i', 0)
+ task_terminated_externally = Value('i', 1)
+ shared_mem_lock = Lock()
+
+ def failure_callback(context):
+ with shared_mem_lock:
+ failure_callback_called.value += 1
+ assert context['dag_run'].dag_id == 'test_mark_failure_2'
+
+ dag = DAG(dag_id='test_mark_failure_2', start_date=DEFAULT_DATE,
+           default_args={'owner': 'owner1'})
+
+ def task_function(ti):
+ # pylint: disable=unused-argument
+ time.sleep(60)
+ # This should not happen -- the state change should be noticed and
+ # the task should get killed
+ with shared_mem_lock:
+ task_terminated_externally.value = 0
+
+ task = PythonOperator(
+ task_id='test_on_failure',
+ python_callable=task_function,
+ retries=1,
+ retry_delay=timedelta(seconds=3),
+ on_failure_callback=failure_callback,
+ dag=dag,
+ )
+
+ session = settings.Session()
+
+ dag.clear()
+ dag.create_dagrun(
+ run_id="test",
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE,
+ session=session,
+ )
+ ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+ ti.refresh_from_db()
+ job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True,
+                     executor=SequentialExecutor())
+ job1.task_runner = StandardTaskRunner(job1)
+
+ settings.engine.dispose()
+ process = multiprocessing.Process(target=job1.run)
+ process.start()
+
+ for _ in range(0, 20):
+ ti.refresh_from_db()
+ if ti.state == State.RUNNING and ti.pid is not None:
+ break
+ time.sleep(0.2)
+ assert ti.state == State.RUNNING
+ assert ti.pid is not None
+ os.kill(ti.pid, signal_type)
Review comment:
I think this is killing the "wrong" process, and is not testing the new
code you wrote.
You want to send the signal to `process` -- and `SIGKILL` cannot be caught,
so you don't need to test that.
(It's only working now because you are sending it to one process further
down, not the one running the LocalTaskJob code you have edited.)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]