uranusjr commented on code in PR #49752:
URL: https://github.com/apache/airflow/pull/49752#discussion_r2065887201
##########
airflow-core/src/airflow/api/common/mark_tasks.py:
##########
@@ -215,17 +215,30 @@ def set_dag_run_state_to_success(
if not run_id:
raise ValueError(f"Invalid dag_run_id: {run_id}")
- # Mark all task instances of the dag run to success - except for teardown
as they need to complete work.
+ # Mark all task instances of the dag run to success - except for
unfinished teardown as they need to complete work.
normal_tasks = [task for task in dag.tasks if not task.is_teardown]
+ teardown_tasks = [task for task in dag.tasks if task.is_teardown]
+ unfinished_teardown_tis: list[TaskInstance] = session.scalars(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == dag.dag_id,
+ TaskInstance.run_id == run_id,
+ TaskInstance.task_id.in_([task.task_id for task in
teardown_tasks]),
+ or_(TaskInstance.state.is_(None),
TaskInstance.state.in_(State.unfinished)),
+ )
+ ).all()
Review Comment:
We only really need the task_id here, so let’s select just that field
instead of the entire TaskInstance object. Also, putting the task_ids in a set
(instead of a list) would be a good idea.
##########
airflow-core/tests/unit/api/common/test_mark_tasks.py:
##########
@@ -54,23 +56,61 @@ def test_set_dag_run_state_to_failed(dag_maker: DagMaker):
assert "teardown" not in task_dict
-def test_set_dag_run_state_to_success(dag_maker: DagMaker):
+@pytest.mark.parametrize(
+ "unfinished_state", sorted([state for state in State.unfinished if state
is not None])
+)
+def test_set_dag_run_state_to_success_unfinished_teardown(dag_maker: DagMaker,
unfinished_state):
with dag_maker("TEST_DAG_1"):
with EmptyOperator(task_id="teardown").as_teardown():
EmptyOperator(task_id="running")
EmptyOperator(task_id="pending")
+
dr = dag_maker.create_dagrun()
for ti in dr.get_task_instances():
if ti.task_id == "running":
ti.set_state(TaskInstanceState.RUNNING)
+ if ti.task_id == "teardown":
+ ti.set_state(unfinished_state)
+
dag_maker.session.flush()
assert dr.dag
updated_tis: list[TaskInstance] = set_dag_run_state_to_success(
dag=dr.dag, run_id=dr.run_id, commit=True, session=dag_maker.session
)
+ run = dag_maker.session.scalar(select(DagRun).filter_by(dag_id=dr.dag_id,
run_id=dr.run_id))
+ assert run.state != DagRunState.SUCCESS
assert len(updated_tis) == 2
task_dict = {ti.task_id: ti for ti in updated_tis}
assert task_dict["running"].state == TaskInstanceState.SUCCESS
assert task_dict["pending"].state == TaskInstanceState.SUCCESS
assert "teardown" not in task_dict
+
+
+@pytest.mark.parametrize(
+ "finished_state", sorted([state for state in State.finished if state !=
TaskInstanceState.SUCCESS])
Review Comment:
```suggestion
"finished_state", sorted(state for state in State.finished if state !=
TaskInstanceState.SUCCESS)
```
(Would be even simpler if we add SUCCESS back as mentioned in a comment
above, just `sorted(State.finished)`)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]