This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5c3e0066aa6 Fix MyPy type errors /api_fastapi/execution_api/,/airflow/jobs/, /airflow/models/ in Sqlalchemy 2 migration (#57277)
5c3e0066aa6 is described below
commit 5c3e0066aa6a85bd4c9695e1b9e6b5b4af9eb6f8
Author: Anusha Kovi <[email protected]>
AuthorDate: Mon Oct 27 19:12:53 2025 +0530
Fix MyPy type errors /api_fastapi/execution_api/,/airflow/jobs/, /airflow/models/ in Sqlalchemy 2 migration (#57277)
---
.../api_fastapi/execution_api/routes/dag_runs.py | 5 ++
.../execution_api/routes/task_instances.py | 83 ++++++++++++++--------
.../src/airflow/jobs/scheduler_job_runner.py | 70 ++++++++++++------
airflow-core/src/airflow/models/taskreschedule.py | 4 +-
4 files changed, 108 insertions(+), 54 deletions(-)
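
The hunks below repeat one defensive pattern: under SQLAlchemy 2 typing, Session.get() and .one_or_none() return Optional values, and the generic Result type does not guarantee a rowcount attribute, so each call site now checks for None before attribute access and reads rowcount via getattr(). A minimal, self-contained sketch of that pattern (illustrative only, not part of the patch; the Task model is a hypothetical stand-in for the patched Airflow models):

    from __future__ import annotations

    import uuid

    from sqlalchemy import String, create_engine, update
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Task(Base):
        # Hypothetical stand-in for the patched Airflow models.
        __tablename__ = "task"
        id: Mapped[str] = mapped_column(String, primary_key=True)
        state: Mapped[str | None] = mapped_column(String, nullable=True)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        task_id = str(uuid.uuid4())
        session.add(Task(id=task_id, state="queued"))
        session.flush()

        task = session.get(Task, task_id)  # typed as Task | None in SQLAlchemy 2
        if task is not None:  # guard before any attribute access, as the patch does
            result = session.execute(
                update(Task).where(Task.id == task_id).values(state="failed")
            )
            # CursorResult exposes rowcount, but the generic Result type does not,
            # hence the defensive getattr(..., "rowcount", 0) used throughout the diff.
            rows_affected = getattr(result, "rowcount", 0)
            print(rows_affected)
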
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
index 22237ee5e28..b3a99bf21a3 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
@@ -129,6 +129,11 @@ def clear_dag_run(
     dag_run = session.scalar(
         select(DagRunModel).where(DagRunModel.dag_id == dag_id, DagRunModel.run_id == run_id)
     )
+    if dag_run is None:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            detail={"reason": "not_found", "message": f"DAG run with run_id: '{run_id}' not found"},
+        )
     dag = get_dag_for_run(dag_bag, dag_run=dag_run, session=session)

     dag.clear(run_id=run_id)
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 2e81d313f0d..bf51488080c 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -210,7 +210,7 @@ def ti_run(
     try:
         result = session.execute(query)
-        log.info("Task instance state updated", rows_affected=result.rowcount)
+        log.info("Task instance state updated", rows_affected=getattr(result, "rowcount", 0))

         dr = (
             session.scalars(
@@ -235,15 +235,15 @@ def ti_run(
         xcom_keys = []
         if not ti.next_method:
             map_index = None if ti.map_index < 0 else ti.map_index
-            query = select(XComModel.key).where(
+            xcom_query = select(XComModel.key).where(
                 XComModel.dag_id == ti.dag_id,
                 XComModel.task_id == ti.task_id,
                 XComModel.run_id == ti.run_id,
             )
             if map_index is not None:
-                query = query.where(XComModel.map_index == map_index)
+                xcom_query = xcom_query.where(XComModel.map_index == map_index)

-            xcom_keys = list(session.scalars(query))
+            xcom_keys = list(session.scalars(xcom_query))
         task_reschedule_count = (
             session.query(
                 func.count(TaskReschedule.id)  # or any other primary key column
@@ -399,15 +399,21 @@ def ti_update_state(
         # Set a task to failed in case any unexpected exception happened during task state update
         log.exception("Error updating Task Instance state to %s. Set the task to failed", updated_state)
         ti = session.get(TI, ti_id_str)
-        query = TI.duration_expression_update(datetime.now(tz=timezone.utc), query, session.bind)
+        if session.bind is not None:
+            query = TI.duration_expression_update(datetime.now(tz=timezone.utc), query, session.bind)
         query = query.values(state=TaskInstanceState.FAILED)
-        _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
+        if ti is not None:
+            _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)

     # TODO: Replace this with FastAPI's Custom Exception handling:
     # https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
     try:
         result = session.execute(query)
-        log.info("Task instance state updated", new_state=updated_state, rows_affected=result.rowcount)
+        log.info(
+            "Task instance state updated",
+            new_state=updated_state,
+            rows_affected=getattr(result, "rowcount", 0),
+        )
     except SQLAlchemyError as e:
         log.error("Error updating Task Instance state", error=str(e))
         raise HTTPException(
@@ -445,21 +451,25 @@ def _create_ti_state_update_query_and_update_state(
     if isinstance(ti_patch_payload, (TITerminalStatePayload, TIRetryStatePayload, TISuccessStatePayload)):
         ti = session.get(TI, ti_id_str)
         updated_state = ti_patch_payload.state
-        query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+        if session.bind is not None:
+            query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
         query = query.values(state=updated_state, next_method=None, next_kwargs=None)
         if updated_state == TerminalTIState.FAILED:
             # This is the only case needs extra handling for TITerminalStatePayload
-            _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
+            if ti is not None:
+                _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
         elif isinstance(ti_patch_payload, TIRetryStatePayload):
-            ti.prepare_db_for_next_try(session)
+            if ti is not None:
+                ti.prepare_db_for_next_try(session)
         elif isinstance(ti_patch_payload, TISuccessStatePayload):
-            TI.register_asset_changes_in_db(
-                ti,
-                ti_patch_payload.task_outlets,  # type: ignore
-                ti_patch_payload.outlet_events,
-                session,
-            )
+            if ti is not None:
+                TI.register_asset_changes_in_db(
+                    ti,
+                    ti_patch_payload.task_outlets,  # type: ignore
+                    ti_patch_payload.outlet_events,
+                    session,
+                )
     elif isinstance(ti_patch_payload, TIDeferredStatePayload):
         # Calculate timeout if it was passed
         timeout = None
@@ -516,26 +526,30 @@ def _create_ti_state_update_query_and_update_state(
             )
             data = ti_patch_payload.model_dump(exclude={"reschedule_date"}, exclude_unset=True)
             query = update(TI).where(TI.id == ti_id_str).values(data)
-            query = TI.duration_expression_update(datetime.now(tz=timezone.utc), query, session.bind)
+            if session.bind is not None:
+                query = TI.duration_expression_update(datetime.now(tz=timezone.utc), query, session.bind)
             query = query.values(state=TaskInstanceState.FAILED)
             ti = session.get(TI, ti_id_str)
-            _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
+            if ti is not None:
+                _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
             return query, updated_state

         task_instance = session.get(TI, ti_id_str)
         actual_start_date = timezone.utcnow()
-        session.add(
-            TaskReschedule(
-                task_instance.id,
-                actual_start_date,
-                ti_patch_payload.end_date,
-                ti_patch_payload.reschedule_date,
+        if task_instance is not None and task_instance.id is not None:
+            session.add(
+                TaskReschedule(
+                    UUID(str(task_instance.id)),
+                    actual_start_date,
+                    ti_patch_payload.end_date,
+                    ti_patch_payload.reschedule_date,
+                )
             )
-        )

         query = update(TI).where(TI.id == ti_id_str)
         # calculate the duration for TI table too
-        query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+        if session.bind is not None:
+            query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
         # clear the next_method and next_kwargs so that none of the retries pick them up
         query = query.values(state=TaskInstanceState.UP_FOR_RESCHEDULE, next_method=None, next_kwargs=None)
         updated_state = TaskInstanceState.UP_FOR_RESCHEDULE
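
The reschedule branch above now wraps the primary key as UUID(str(task_instance.id)) before handing it to TaskReschedule, whose constructor (patched at the end of this diff) accepts uuid.UUID | str and stores str(ti_id). A tiny illustration of why that round-trip is lossless, assuming the id is a UUID or its canonical string form (not part of the patch):

    import uuid

    # Whether the ORM returns the id as a UUID object or as its string form,
    # UUID(str(...)) re-parses it and TaskReschedule.__init__ stores str(ti_id),
    # so both inputs normalize to the same stored value.
    as_uuid = uuid.uuid4()
    as_text = str(as_uuid)

    assert uuid.UUID(str(as_uuid)) == uuid.UUID(str(as_text))
    assert str(uuid.UUID(str(as_text))) == as_text
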
@@ -565,7 +579,14 @@ def ti_skip_downstream(
     now = timezone.utcnow()
     tasks = ti_patch_payload.tasks

-    dag_id, run_id = session.execute(select(TI.dag_id, TI.run_id).where(TI.id == ti_id_str)).fetchone()
+    query_result = session.execute(select(TI.dag_id, TI.run_id).where(TI.id == ti_id_str))
+    row_result = query_result.fetchone()
+    if row_result is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail={"reason": "not_found", "message": "Task Instance not found"},
+        )
+    dag_id, run_id = row_result
     log.debug("Retrieved DAG and run info", dag_id=dag_id, run_id=run_id)

     task_ids = [task if isinstance(task, tuple) else (task, -1) for task in tasks]
@@ -579,7 +600,7 @@ def ti_skip_downstream(
     )
     result = session.execute(query)
-    log.info("Downstream tasks skipped", tasks_skipped=result.rowcount)
+    log.info("Downstream tasks skipped", tasks_skipped=getattr(result, "rowcount", 0))


 @ti_id_router.put(
@@ -955,8 +976,10 @@ def validate_inlets_and_outlets(
     with contextlib.suppress(TaskNotFound):
         ti.task = dag.get_task(ti.task_id)

-    inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
-    outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
+    inlets = [asset.asprofile() for asset in (ti.task.inlets if ti.task else []) if isinstance(asset, Asset)]
+    outlets = [
+        asset.asprofile() for asset in (ti.task.outlets if ti.task else []) if isinstance(asset, Asset)
+    ]
     if not (inlets or outlets):
         return InactiveAssetsResponse(inactive_assets=[])
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 7be7150a747..f4e5f4943d6 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -84,7 +84,12 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.span_status import SpanStatus
-from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, with_row_locks
+from airflow.utils.sqlalchemy import (
+    get_dialect_name,
+    is_lock_not_available_error,
+    prohibit_commit,
+    with_row_locks,
+)
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.thread_safe_dict import ThreadSafeDict
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -340,7 +345,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         executable_tis: list[TI] = []

-        if session.get_bind().dialect.name == "postgresql":
+        if get_dialect_name(session) == "postgresql":
             # Optimization: to avoid littering the DB errors of "ERROR: canceling statement due to lock
             # timeout", try to take out a transactional advisory lock (unlocks automatically on
             # COMMIT/ROLLBACK)
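
get_dialect_name is imported from airflow.utils.sqlalchemy above, but its body is not part of this diff; presumably it wraps the session.get_bind().dialect.name lookup it replaces here. A rough sketch of what such a helper might look like (an assumption, not the actual Airflow implementation):

    from sqlalchemy.orm import Session


    def get_dialect_name(session: Session) -> str:
        """Return the bind's dialect name, e.g. "postgresql" or "sqlite"."""
        # Assumed behaviour: the same lookup the old inline code performed.
        return session.get_bind().dialect.name
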
@@ -349,6 +354,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                     id=DBLocks.SCHEDULER_CRITICAL_SECTION.value
                 )
             ).scalar()
+            if lock_acquired is None:
+                lock_acquired = False
             if not lock_acquired:
                 # Throw an error like the one that would happen with NOWAIT
                 raise OperationalError(
@@ -1126,18 +1133,23 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             prefix, sep, key = prefixed_key.partition(":")

             if prefix == "ti":
-                ti: TaskInstance | None = session.get(TaskInstance, key)
+                ti_result = session.get(TaskInstance, key)
+                if ti_result is None:
+                    continue
+                ti: TaskInstance = ti_result

-                if ti is not None:
-                    if ti.state in State.finished:
-                        self.set_ti_span_attrs(span=span, state=ti.state, ti=ti)
-                        span.end(end_time=datetime_to_nano(ti.end_date))
-                        ti.span_status = SpanStatus.ENDED
-                    else:
-                        span.end()
-                        ti.span_status = SpanStatus.NEEDS_CONTINUANCE
+                if ti.state in State.finished:
+                    self.set_ti_span_attrs(span=span, state=ti.state, ti=ti)
+                    span.end(end_time=datetime_to_nano(ti.end_date))
+                    ti.span_status = SpanStatus.ENDED
+                else:
+                    span.end()
+                    ti.span_status = SpanStatus.NEEDS_CONTINUANCE
             elif prefix == "dr":
-                dag_run: DagRun = session.scalars(select(DagRun).where(DagRun.id == int(key))).one()
+                dag_run_result = session.scalars(select(DagRun).where(DagRun.id == int(key))).one_or_none()
+                if dag_run_result is None:
+                    continue
+                dag_run: DagRun = dag_run_result

                 if dag_run.state in State.finished_dr_states:
                     dag_run.set_dagrun_span_attrs(span=span)
@@ -1213,12 +1225,15 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             and dag_run.span_status == SpanStatus.ACTIVE
         ):
             initial_scheduler_id = dag_run.scheduled_by_job_id
-            job: Job = session.scalars(
+            job_result = session.scalars(
                 select(Job).where(
                     Job.id == initial_scheduler_id,
                     Job.job_type == "SchedulerJob",
                 )
-            ).one()
+            ).one_or_none()
+            if job_result is None:
+                return
+            job: Job = job_result

             if not job.is_alive():
                 # Start a new span for the dag_run.
@@ -1916,9 +1931,15 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):

         dag = dag_run.dag = self.scheduler_dag_bag.get_dag_for_run(dag_run=dag_run, session=session)
         dag_model = DM.get_dagmodel(dag_run.dag_id, session)
+        if dag_model is None:
+            self.log.error("Couldn't find DAG model %s in database!", dag_run.dag_id)
+            return callback

-        if not dag or not dag_model:
-            self.log.error("Couldn't find DAG %s in DAG bag or database!", dag_run.dag_id)
+        if not dag:
+            self.log.error("Couldn't find DAG %s in DAG bag!", dag_run.dag_id)
+            return callback
+        if not dag_model:
+            self.log.error("Couldn't find DAG model %s in database!", dag_run.dag_id)
             return callback

         if (
@@ -2026,6 +2047,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         Return True if we determine that DAG still exists.
         """
         latest_dag_version = DagVersion.get_latest_version(dag_run.dag_id, session=session)
+        if latest_dag_version is None:
+            return False
         if TYPE_CHECKING:
             assert latest_dag_version
@@ -2199,10 +2222,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                 Log.event == TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT,
             )

-            if last_running_time:
+            if last_running_time is not None:
                 query = query.where(Log.dttm > last_running_time)

-            return query.count()
+            count_result = query.count()
+            return count_result if count_result is not None else 0

         previous_ti_running_metrics: dict[tuple[str, str, str], int] = {}
@@ -2285,7 +2309,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             )
             self.log.debug("Calling SchedulerJob.adopt_or_reset_orphaned_tasks method")
             try:
-                num_failed = session.execute(
+                result = session.execute(
                     update(Job)
                     .where(
                         Job.job_type == "SchedulerJob",
@@ -2293,7 +2317,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                         Job.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout)),
                     )
                     .values(state=JobState.FAILED)
-                ).rowcount
+                )
+                num_failed = getattr(result, "rowcount", 0)

                 if num_failed:
                     self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
@@ -2361,7 +2386,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
"""Mark any "deferred" task as failed if the trigger or execution
timeout has passed."""
for attempt in run_with_db_retries(max_retries, logger=self.log):
with attempt:
- num_timed_out_tasks = session.execute(
+ result = session.execute(
update(TI)
.where(
TI.state == TaskInstanceState.DEFERRED,
@@ -2374,7 +2399,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                         scheduled_dttm=timezone.utcnow(),
                         trigger_id=None,
                     )
-                ).rowcount
+                )
+                num_timed_out_tasks = getattr(result, "rowcount", 0)

                 if num_timed_out_tasks:
                     self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
diff --git a/airflow-core/src/airflow/models/taskreschedule.py b/airflow-core/src/airflow/models/taskreschedule.py
index 996bce62d0d..db36b580787 100644
--- a/airflow-core/src/airflow/models/taskreschedule.py
+++ b/airflow-core/src/airflow/models/taskreschedule.py
@@ -66,12 +66,12 @@ class TaskReschedule(Base):

     def __init__(
         self,
-        ti_id: uuid.UUID,
+        ti_id: uuid.UUID | str,
         start_date: datetime.datetime,
         end_date: datetime.datetime,
         reschedule_date: datetime.datetime,
     ) -> None:
-        self.ti_id = ti_id
+        self.ti_id = str(ti_id)
         self.start_date = start_date
         self.end_date = end_date
         self.reschedule_date = reschedule_date