uranusjr commented on code in PR #49752:
URL: https://github.com/apache/airflow/pull/49752#discussion_r2065887201


##########
airflow-core/src/airflow/api/common/mark_tasks.py:
##########
@@ -215,17 +215,30 @@ def set_dag_run_state_to_success(
     if not run_id:
         raise ValueError(f"Invalid dag_run_id: {run_id}")
 
-    # Mark all task instances of the dag run to success - except for teardown 
as they need to complete work.
+    # Mark all task instances of the dag run to success - except for 
unfinished teardown as they need to complete work.
     normal_tasks = [task for task in dag.tasks if not task.is_teardown]
+    teardown_tasks = [task for task in dag.tasks if task.is_teardown]
+    unfinished_teardown_tis: list[TaskInstance] = session.scalars(
+        select(TaskInstance).where(
+            TaskInstance.dag_id == dag.dag_id,
+            TaskInstance.run_id == run_id,
+            TaskInstance.task_id.in_([task.task_id for task in 
teardown_tasks]),
+            or_(TaskInstance.state.is_(None), 
TaskInstance.state.in_(State.unfinished)),
+        )
+    ).all()

Review Comment:
   We only really need the task_id here, so let’s just select that field 
instead of the entire TaskInstance object. Also, putting the task_ids in a set 
(instead of a list) would be a good idea.



##########
airflow-core/tests/unit/api/common/test_mark_tasks.py:
##########
@@ -54,23 +56,61 @@ def test_set_dag_run_state_to_failed(dag_maker: DagMaker):
     assert "teardown" not in task_dict
 
 
-def test_set_dag_run_state_to_success(dag_maker: DagMaker):
+@pytest.mark.parametrize(
+    "unfinished_state", sorted([state for state in State.unfinished if state 
is not None])
+)
+def test_set_dag_run_state_to_success_unfinished_teardown(dag_maker: DagMaker, 
unfinished_state):
     with dag_maker("TEST_DAG_1"):
         with EmptyOperator(task_id="teardown").as_teardown():
             EmptyOperator(task_id="running")
             EmptyOperator(task_id="pending")
+
     dr = dag_maker.create_dagrun()
     for ti in dr.get_task_instances():
         if ti.task_id == "running":
             ti.set_state(TaskInstanceState.RUNNING)
+        if ti.task_id == "teardown":
+            ti.set_state(unfinished_state)
+
     dag_maker.session.flush()
     assert dr.dag
 
     updated_tis: list[TaskInstance] = set_dag_run_state_to_success(
         dag=dr.dag, run_id=dr.run_id, commit=True, session=dag_maker.session
     )
+    run = dag_maker.session.scalar(select(DagRun).filter_by(dag_id=dr.dag_id, 
run_id=dr.run_id))
+    assert run.state != DagRunState.SUCCESS
     assert len(updated_tis) == 2
     task_dict = {ti.task_id: ti for ti in updated_tis}
     assert task_dict["running"].state == TaskInstanceState.SUCCESS
     assert task_dict["pending"].state == TaskInstanceState.SUCCESS
     assert "teardown" not in task_dict
+
+
+@pytest.mark.parametrize(
+    "finished_state", sorted([state for state in State.finished if state != 
TaskInstanceState.SUCCESS])

Review Comment:
   ```suggestion
       "finished_state", sorted(state for state in State.finished if state != 
TaskInstanceState.SUCCESS)
   ```
   
   (This would be even simpler if we add SUCCESS back, as mentioned in a comment 
above: just `sorted(State.finished)`.)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to