This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 706bd12a9ff Fix MyPy type errors in mark_tasks.py for SQLAlchemy 2 migration (#57271)
706bd12a9ff is described below
commit 706bd12a9ff4cf743ef2fd4015f69b8cea38e5e5
Author: Anusha Kovi <[email protected]>
AuthorDate: Mon Oct 27 19:00:17 2025 +0530
Fix MyPy type errors in mark_tasks.py for SQLAlchemy 2 migration (#57271)
---
airflow-core/src/airflow/api/common/mark_tasks.py | 54 +++++++++++++----------
1 file changed, 30 insertions(+), 24 deletions(-)
diff --git a/airflow-core/src/airflow/api/common/mark_tasks.py b/airflow-core/src/airflow/api/common/mark_tasks.py
index 5c0ed4b9f5f..0e1d8610ddf 100644
--- a/airflow-core/src/airflow/api/common/mark_tasks.py
+++ b/airflow-core/src/airflow/api/common/mark_tasks.py
@@ -20,7 +20,7 @@
from __future__ import annotations
from collections.abc import Collection, Iterable
-from typing import TYPE_CHECKING, TypeAlias
+from typing import TYPE_CHECKING, TypeAlias, cast
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import lazyload
@@ -32,6 +32,7 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState
if TYPE_CHECKING:
from sqlalchemy.orm import Session as SASession
+ from sqlalchemy.sql import ColumnElement
from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import
SerializedBaseOperator, SerializedDAG
@@ -92,12 +93,12 @@ def set_state(
qry_dag = get_all_dag_task_query(dag, state, task_id_map_index_list,
dag_run_ids)
if commit:
- tis_altered = session.scalars(qry_dag.with_for_update()).all()
+ tis_altered = list(session.scalars(qry_dag.with_for_update()).all())
for task_instance in tis_altered:
task_instance.set_state(state, session=session)
session.flush()
else:
- tis_altered = session.scalars(qry_dag).all()
+ tis_altered = list(session.scalars(qry_dag).all())
return tis_altered
@@ -111,8 +112,9 @@ def get_all_dag_task_query(
qry_dag = select(TaskInstance).where(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id.in_(run_ids),
- TaskInstance.ti_selector_condition(task_ids),
)
+ # Apply ti_selector_condition separately to handle type issues
+ qry_dag = qry_dag.where(cast("ColumnElement[bool]",
TaskInstance.ti_selector_condition(task_ids)))
qry_dag = qry_dag.where(or_(TaskInstance.state.is_(None),
TaskInstance.state != state)).options(
lazyload(TaskInstance.dag_run)
@@ -294,14 +296,16 @@ def set_dag_run_state_to_failed(
# Mark only RUNNING task instances.
task_ids = [task.task_id for task in dag.tasks]
- running_tis: list[TaskInstance] = session.scalars(
- select(TaskInstance).where(
- TaskInstance.dag_id == dag.dag_id,
- TaskInstance.run_id == run_id,
- TaskInstance.task_id.in_(task_ids),
- TaskInstance.state.in_(running_states),
- )
- ).all()
+ running_tis: list[TaskInstance] = list(
+ session.scalars(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == dag.dag_id,
+ TaskInstance.run_id == run_id,
+ TaskInstance.task_id.in_(task_ids),
+ TaskInstance.state.in_(running_states),
+ )
+ ).all()
+ )
# Do not kill teardown tasks
task_ids_of_running_tis = {ti.task_id for ti in running_tis if not
dag.task_dict[ti.task_id].is_teardown}
@@ -313,19 +317,21 @@ def set_dag_run_state_to_failed(
running_tasks = [_set_runing_task(task) for task in dag.tasks if
task.task_id in task_ids_of_running_tis]
# Mark non-finished tasks as SKIPPED.
- pending_tis: list[TaskInstance] = session.scalars(
- select(TaskInstance).filter(
- TaskInstance.dag_id == dag.dag_id,
- TaskInstance.run_id == run_id,
- or_(
- TaskInstance.state.is_(None),
- and_(
- TaskInstance.state.not_in(State.finished),
- TaskInstance.state.not_in(running_states),
+ pending_tis: list[TaskInstance] = list(
+ session.scalars(
+ select(TaskInstance).filter(
+ TaskInstance.dag_id == dag.dag_id,
+ TaskInstance.run_id == run_id,
+ or_(
+ TaskInstance.state.is_(None),
+ and_(
+ TaskInstance.state.not_in(State.finished),
+ TaskInstance.state.not_in(running_states),
+ ),
),
- ),
- )
- ).all()
+ )
+ ).all()
+ )
# Do not skip teardown tasks
pending_normal_tis = [ti for ti in pending_tis if not
dag.task_dict[ti.task_id].is_teardown]