This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 706bd12a9ff Fix MyPy type errors in mark_tasks.py for Sqlalchemy 2 migration (#57271)
706bd12a9ff is described below

commit 706bd12a9ff4cf743ef2fd4015f69b8cea38e5e5
Author: Anusha Kovi <[email protected]>
AuthorDate: Mon Oct 27 19:00:17 2025 +0530

    Fix MyPy type errors in mark_tasks.py for Sqlalchemy 2 migration (#57271)
---
 airflow-core/src/airflow/api/common/mark_tasks.py | 54 +++++++++++++----------
 1 file changed, 30 insertions(+), 24 deletions(-)

diff --git a/airflow-core/src/airflow/api/common/mark_tasks.py b/airflow-core/src/airflow/api/common/mark_tasks.py
index 5c0ed4b9f5f..0e1d8610ddf 100644
--- a/airflow-core/src/airflow/api/common/mark_tasks.py
+++ b/airflow-core/src/airflow/api/common/mark_tasks.py
@@ -20,7 +20,7 @@
 from __future__ import annotations
 
 from collections.abc import Collection, Iterable
-from typing import TYPE_CHECKING, TypeAlias
+from typing import TYPE_CHECKING, TypeAlias, cast
 
 from sqlalchemy import and_, or_, select
 from sqlalchemy.orm import lazyload
@@ -32,6 +32,7 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session as SASession
+    from sqlalchemy.sql import ColumnElement
 
     from airflow.models.mappedoperator import MappedOperator
     from airflow.serialization.serialized_objects import 
SerializedBaseOperator, SerializedDAG
@@ -92,12 +93,12 @@ def set_state(
     qry_dag = get_all_dag_task_query(dag, state, task_id_map_index_list, 
dag_run_ids)
 
     if commit:
-        tis_altered = session.scalars(qry_dag.with_for_update()).all()
+        tis_altered = list(session.scalars(qry_dag.with_for_update()).all())
         for task_instance in tis_altered:
             task_instance.set_state(state, session=session)
         session.flush()
     else:
-        tis_altered = session.scalars(qry_dag).all()
+        tis_altered = list(session.scalars(qry_dag).all())
     return tis_altered
 
 
@@ -111,8 +112,9 @@ def get_all_dag_task_query(
     qry_dag = select(TaskInstance).where(
         TaskInstance.dag_id == dag.dag_id,
         TaskInstance.run_id.in_(run_ids),
-        TaskInstance.ti_selector_condition(task_ids),
     )
+    # Apply ti_selector_condition separately to handle type issues
+    qry_dag = qry_dag.where(cast("ColumnElement[bool]", 
TaskInstance.ti_selector_condition(task_ids)))
 
     qry_dag = qry_dag.where(or_(TaskInstance.state.is_(None), 
TaskInstance.state != state)).options(
         lazyload(TaskInstance.dag_run)
@@ -294,14 +296,16 @@ def set_dag_run_state_to_failed(
 
     # Mark only RUNNING task instances.
     task_ids = [task.task_id for task in dag.tasks]
-    running_tis: list[TaskInstance] = session.scalars(
-        select(TaskInstance).where(
-            TaskInstance.dag_id == dag.dag_id,
-            TaskInstance.run_id == run_id,
-            TaskInstance.task_id.in_(task_ids),
-            TaskInstance.state.in_(running_states),
-        )
-    ).all()
+    running_tis: list[TaskInstance] = list(
+        session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == dag.dag_id,
+                TaskInstance.run_id == run_id,
+                TaskInstance.task_id.in_(task_ids),
+                TaskInstance.state.in_(running_states),
+            )
+        ).all()
+    )
 
     # Do not kill teardown tasks
     task_ids_of_running_tis = {ti.task_id for ti in running_tis if not 
dag.task_dict[ti.task_id].is_teardown}
@@ -313,19 +317,21 @@ def set_dag_run_state_to_failed(
     running_tasks = [_set_runing_task(task) for task in dag.tasks if 
task.task_id in task_ids_of_running_tis]
 
     # Mark non-finished tasks as SKIPPED.
-    pending_tis: list[TaskInstance] = session.scalars(
-        select(TaskInstance).filter(
-            TaskInstance.dag_id == dag.dag_id,
-            TaskInstance.run_id == run_id,
-            or_(
-                TaskInstance.state.is_(None),
-                and_(
-                    TaskInstance.state.not_in(State.finished),
-                    TaskInstance.state.not_in(running_states),
+    pending_tis: list[TaskInstance] = list(
+        session.scalars(
+            select(TaskInstance).filter(
+                TaskInstance.dag_id == dag.dag_id,
+                TaskInstance.run_id == run_id,
+                or_(
+                    TaskInstance.state.is_(None),
+                    and_(
+                        TaskInstance.state.not_in(State.finished),
+                        TaskInstance.state.not_in(running_states),
+                    ),
                 ),
-            ),
-        )
-    ).all()
+            )
+        ).all()
+    )
 
     # Do not skip teardown tasks
     pending_normal_tis = [ti for ti in pending_tis if not 
dag.task_dict[ti.task_id].is_teardown]

Reply via email to