This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c3e0066aa6 Fix MyPy type errors /api_fastapi/execution_api/,/airflow/jobs/, /airflow/models/ in Sqlalchemy 2 migration (#57277)
5c3e0066aa6 is described below

commit 5c3e0066aa6a85bd4c9695e1b9e6b5b4af9eb6f8
Author: Anusha Kovi <[email protected]>
AuthorDate: Mon Oct 27 19:12:53 2025 +0530

    Fix MyPy type errors /api_fastapi/execution_api/,/airflow/jobs/, /airflow/models/ in Sqlalchemy 2 migration (#57277)
---
 .../api_fastapi/execution_api/routes/dag_runs.py   |  5 ++
 .../execution_api/routes/task_instances.py         | 83 ++++++++++++++--------
 .../src/airflow/jobs/scheduler_job_runner.py       | 70 ++++++++++++------
 airflow-core/src/airflow/models/taskreschedule.py  |  4 +-
 4 files changed, 108 insertions(+), 54 deletions(-)
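
The changes share one root cause: under SQLAlchemy 2.0's typed API, Session.get(), Session.scalar(), Result.fetchone() and friends are annotated as returning Optional values, so MyPy rejects any unguarded attribute access on their results. The fix throughout is an explicit None guard (or an HTTP 404 at API boundaries). A minimal sketch of the narrowing pattern, using a placeholder Widget model and an in-memory SQLite engine rather than any Airflow code:

    from sqlalchemy import create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Widget(Base):
        __tablename__ = "widget"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # Session.get() is typed as Widget | None, so MyPy rejects an
        # unguarded widget.name; the explicit check narrows the type.
        widget = session.get(Widget, 1)
        if widget is not None:
            print(widget.name)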

diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
index 22237ee5e28..b3a99bf21a3 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
@@ -129,6 +129,11 @@ def clear_dag_run(
     dag_run = session.scalar(
         select(DagRunModel).where(DagRunModel.dag_id == dag_id, DagRunModel.run_id == run_id)
     )
+    if dag_run is None:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND,
+            detail={"reason": "not_found", "message": f"DAG run with run_id: 
'{run_id}' not found"},
+        )
     dag = get_dag_for_run(dag_bag, dag_run=dag_run, session=session)
 
     dag.clear(run_id=run_id)
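
Raising in the None branch does double duty: the endpoint now returns a proper 404 instead of failing later with an AttributeError, and MyPy narrows dag_run to non-None for the get_dag_for_run() call that follows. A condensed, hypothetical restatement (require_dag_run is illustrative, not part of the patch; only HTTPException and status are real FastAPI names):

    from fastapi import HTTPException, status

    def require_dag_run(dag_run, run_id: str):
        # Everything after this if-block may treat dag_run as non-None,
        # because the None branch never falls through.
        if dag_run is None:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND,
                detail={"reason": "not_found", "message": f"DAG run with run_id: '{run_id}' not found"},
            )
        return dag_run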
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 2e81d313f0d..bf51488080c 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -210,7 +210,7 @@ def ti_run(
 
     try:
         result = session.execute(query)
-        log.info("Task instance state updated", rows_affected=result.rowcount)
+        log.info("Task instance state updated", rows_affected=getattr(result, 
"rowcount", 0))
 
         dr = (
             session.scalars(
@@ -235,15 +235,15 @@ def ti_run(
         xcom_keys = []
         if not ti.next_method:
             map_index = None if ti.map_index < 0 else ti.map_index
-            query = select(XComModel.key).where(
+            xcom_query = select(XComModel.key).where(
                 XComModel.dag_id == ti.dag_id,
                 XComModel.task_id == ti.task_id,
                 XComModel.run_id == ti.run_id,
             )
             if map_index is not None:
-                query = query.where(XComModel.map_index == map_index)
+                xcom_query = xcom_query.where(XComModel.map_index == map_index)
 
-            xcom_keys = list(session.scalars(query))
+            xcom_keys = list(session.scalars(xcom_query))
         task_reschedule_count = (
             session.query(
                 func.count(TaskReschedule.id)  # or any other primary key column
@@ -399,15 +399,21 @@ def ti_update_state(
         # Set a task to failed in case any unexpected exception happened during task state update
         log.exception("Error updating Task Instance state to %s. Set the task to failed", updated_state)
         ti = session.get(TI, ti_id_str)
-        query = TI.duration_expression_update(datetime.now(tz=timezone.utc), query, session.bind)
+        if session.bind is not None:
+            query = TI.duration_expression_update(datetime.now(tz=timezone.utc), query, session.bind)
         query = query.values(state=TaskInstanceState.FAILED)
-        _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
+        if ti is not None:
+            _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
 
     # TODO: Replace this with FastAPI's Custom Exception handling:
     # https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
     try:
         result = session.execute(query)
-        log.info("Task instance state updated", new_state=updated_state, 
rows_affected=result.rowcount)
+        log.info(
+            "Task instance state updated",
+            new_state=updated_state,
+            rows_affected=getattr(result, "rowcount", 0),
+        )
     except SQLAlchemyError as e:
         log.error("Error updating Task Instance state", error=str(e))
         raise HTTPException(
@@ -445,21 +451,25 @@ def _create_ti_state_update_query_and_update_state(
     if isinstance(ti_patch_payload, (TITerminalStatePayload, TIRetryStatePayload, TISuccessStatePayload)):
         ti = session.get(TI, ti_id_str)
         updated_state = ti_patch_payload.state
-        query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+        if session.bind is not None:
+            query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
         query = query.values(state=updated_state, next_method=None, next_kwargs=None)
 
         if updated_state == TerminalTIState.FAILED:
             # This is the only case needs extra handling for TITerminalStatePayload
-            _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
+            if ti is not None:
+                _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
         elif isinstance(ti_patch_payload, TIRetryStatePayload):
-            ti.prepare_db_for_next_try(session)
+            if ti is not None:
+                ti.prepare_db_for_next_try(session)
         elif isinstance(ti_patch_payload, TISuccessStatePayload):
-            TI.register_asset_changes_in_db(
-                ti,
-                ti_patch_payload.task_outlets,  # type: ignore
-                ti_patch_payload.outlet_events,
-                session,
-            )
+            if ti is not None:
+                TI.register_asset_changes_in_db(
+                    ti,
+                    ti_patch_payload.task_outlets,  # type: ignore
+                    ti_patch_payload.outlet_events,
+                    session,
+                )
     elif isinstance(ti_patch_payload, TIDeferredStatePayload):
         # Calculate timeout if it was passed
         timeout = None
@@ -516,26 +526,30 @@ def _create_ti_state_update_query_and_update_state(
                 )
                 data = ti_patch_payload.model_dump(exclude={"reschedule_date"}, exclude_unset=True)
                 query = update(TI).where(TI.id == ti_id_str).values(data)
-                query = TI.duration_expression_update(datetime.now(tz=timezone.utc), query, session.bind)
+                if session.bind is not None:
+                    query = TI.duration_expression_update(datetime.now(tz=timezone.utc), query, session.bind)
                 query = query.values(state=TaskInstanceState.FAILED)
                 ti = session.get(TI, ti_id_str)
-                _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
+                if ti is not None:
+                    _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, session=session, dag_bag=dag_bag)
                 return query, updated_state
 
         task_instance = session.get(TI, ti_id_str)
         actual_start_date = timezone.utcnow()
-        session.add(
-            TaskReschedule(
-                task_instance.id,
-                actual_start_date,
-                ti_patch_payload.end_date,
-                ti_patch_payload.reschedule_date,
+        if task_instance is not None and task_instance.id is not None:
+            session.add(
+                TaskReschedule(
+                    UUID(str(task_instance.id)),
+                    actual_start_date,
+                    ti_patch_payload.end_date,
+                    ti_patch_payload.reschedule_date,
+                )
             )
-        )
 
         query = update(TI).where(TI.id == ti_id_str)
         # calculate the duration for TI table too
-        query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+        if session.bind is not None:
+            query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
         # clear the next_method and next_kwargs so that none of the retries pick them up
         query = query.values(state=TaskInstanceState.UP_FOR_RESCHEDULE, next_method=None, next_kwargs=None)
         updated_state = TaskInstanceState.UP_FOR_RESCHEDULE
@@ -565,7 +579,14 @@ def ti_skip_downstream(
     now = timezone.utcnow()
     tasks = ti_patch_payload.tasks
 
-    dag_id, run_id = session.execute(select(TI.dag_id, TI.run_id).where(TI.id == ti_id_str)).fetchone()
+    query_result = session.execute(select(TI.dag_id, TI.run_id).where(TI.id == ti_id_str))
+    row_result = query_result.fetchone()
+    if row_result is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail={"reason": "not_found", "message": "Task Instance not 
found"},
+        )
+    dag_id, run_id = row_result
     log.debug("Retrieved DAG and run info", dag_id=dag_id, run_id=run_id)
 
     task_ids = [task if isinstance(task, tuple) else (task, -1) for task in tasks]
@@ -579,7 +600,7 @@ def ti_skip_downstream(
     )
 
     result = session.execute(query)
-    log.info("Downstream tasks skipped", tasks_skipped=result.rowcount)
+    log.info("Downstream tasks skipped", tasks_skipped=getattr(result, 
"rowcount", 0))
 
 
 @ti_id_router.put(
@@ -955,8 +976,10 @@ def validate_inlets_and_outlets(
             with contextlib.suppress(TaskNotFound):
                 ti.task = dag.get_task(ti.task_id)
 
-    inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
-    outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
+    inlets = [asset.asprofile() for asset in (ti.task.inlets if ti.task else []) if isinstance(asset, Asset)]
+    outlets = [
+        asset.asprofile() for asset in (ti.task.outlets if ti.task else []) if isinstance(asset, Asset)
+    ]
     if not (inlets or outlets):
         return InactiveAssetsResponse(inactive_assets=[])
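
The recurring getattr(result, "rowcount", 0) hedge is there because Session.execute() is annotated as returning Result, and rowcount is only declared on the CursorResult subclass, so a direct attribute access fails strict type checking even though it works at runtime for UPDATE and DELETE statements. A sketch reusing the Widget/engine placeholders from the earlier example; an explicit cast is an alternative spelling:

    from typing import cast

    from sqlalchemy import update
    from sqlalchemy.engine import CursorResult
    from sqlalchemy.orm import Session

    with Session(engine) as session:
        result = session.execute(update(Widget).values(name="renamed"))
        rows = getattr(result, "rowcount", 0)           # duck-typed hedge, as in the diff
        rows_alt = cast(CursorResult, result).rowcount  # explicit-cast alternative
        session.commit()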
 
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 7be7150a747..f4e5f4943d6 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -84,7 +84,12 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.span_status import SpanStatus
-from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, with_row_locks
+from airflow.utils.sqlalchemy import (
+    get_dialect_name,
+    is_lock_not_available_error,
+    prohibit_commit,
+    with_row_locks,
+)
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.thread_safe_dict import ThreadSafeDict
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
@@ -340,7 +345,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
 
         executable_tis: list[TI] = []
 
-        if session.get_bind().dialect.name == "postgresql":
+        if get_dialect_name(session) == "postgresql":
             # Optimization: to avoid littering the DB errors of "ERROR: canceling statement due to lock
             # timeout", try to take out a transactional advisory lock (unlocks automatically on
             # COMMIT/ROLLBACK)
@@ -349,6 +354,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                     id=DBLocks.SCHEDULER_CRITICAL_SECTION.value
                 )
             ).scalar()
+            if lock_acquired is None:
+                lock_acquired = False
             if not lock_acquired:
                 # Throw an error like the one that would happen with NOWAIT
                 raise OperationalError(
@@ -1126,18 +1133,23 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             prefix, sep, key = prefixed_key.partition(":")
 
             if prefix == "ti":
-                ti: TaskInstance | None = session.get(TaskInstance, key)
+                ti_result = session.get(TaskInstance, key)
+                if ti_result is None:
+                    continue
+                ti: TaskInstance = ti_result
 
-                if ti is not None:
-                    if ti.state in State.finished:
-                        self.set_ti_span_attrs(span=span, state=ti.state, ti=ti)
-                        span.end(end_time=datetime_to_nano(ti.end_date))
-                        ti.span_status = SpanStatus.ENDED
-                    else:
-                        span.end()
-                        ti.span_status = SpanStatus.NEEDS_CONTINUANCE
+                if ti.state in State.finished:
+                    self.set_ti_span_attrs(span=span, state=ti.state, ti=ti)
+                    span.end(end_time=datetime_to_nano(ti.end_date))
+                    ti.span_status = SpanStatus.ENDED
+                else:
+                    span.end()
+                    ti.span_status = SpanStatus.NEEDS_CONTINUANCE
             elif prefix == "dr":
-                dag_run: DagRun = session.scalars(select(DagRun).where(DagRun.id == int(key))).one()
+                dag_run_result = session.scalars(select(DagRun).where(DagRun.id == int(key))).one_or_none()
+                if dag_run_result is None:
+                    continue
+                dag_run: DagRun = dag_run_result
                 if dag_run.state in State.finished_dr_states:
                     dag_run.set_dagrun_span_attrs(span=span)
 
@@ -1213,12 +1225,15 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             and dag_run.span_status == SpanStatus.ACTIVE
         ):
             initial_scheduler_id = dag_run.scheduled_by_job_id
-            job: Job = session.scalars(
+            job_result = session.scalars(
                 select(Job).where(
                     Job.id == initial_scheduler_id,
                     Job.job_type == "SchedulerJob",
                 )
-            ).one()
+            ).one_or_none()
+            if job_result is None:
+                return
+            job: Job = job_result
 
             if not job.is_alive():
                 # Start a new span for the dag_run.
@@ -1916,9 +1931,15 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
 
             dag = dag_run.dag = self.scheduler_dag_bag.get_dag_for_run(dag_run=dag_run, session=session)
             dag_model = DM.get_dagmodel(dag_run.dag_id, session)
+            if dag_model is None:
+                self.log.error("Couldn't find DAG model %s in database!", dag_run.dag_id)
+                return callback
 
-            if not dag or not dag_model:
-                self.log.error("Couldn't find DAG %s in DAG bag or database!", dag_run.dag_id)
+            if not dag:
+                self.log.error("Couldn't find DAG %s in DAG bag!", dag_run.dag_id)
+                return callback
+            if not dag_model:
+                self.log.error("Couldn't find DAG model %s in database!", dag_run.dag_id)
                 return callback
 
             if (
@@ -2026,6 +2047,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         Return True if we determine that DAG still exists.
         """
         latest_dag_version = DagVersion.get_latest_version(dag_run.dag_id, session=session)
+        if latest_dag_version is None:
+            return False
         if TYPE_CHECKING:
             assert latest_dag_version
 
@@ -2199,10 +2222,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             Log.event == TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT,
         )
 
-        if last_running_time:
+        if last_running_time is not None:
             query = query.where(Log.dttm > last_running_time)
 
-        return query.count()
+        count_result = query.count()
+        return count_result if count_result is not None else 0
 
     previous_ti_running_metrics: dict[tuple[str, str, str], int] = {}
 
@@ -2285,7 +2309,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                 )
                 self.log.debug("Calling 
SchedulerJob.adopt_or_reset_orphaned_tasks method")
                 try:
-                    num_failed = session.execute(
+                    result = session.execute(
                         update(Job)
                         .where(
                             Job.job_type == "SchedulerJob",
@@ -2293,7 +2317,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                             Job.latest_heartbeat < (timezone.utcnow() - timedelta(seconds=timeout)),
                         )
                         .values(state=JobState.FAILED)
-                    ).rowcount
+                    )
+                    num_failed = getattr(result, "rowcount", 0)
 
                     if num_failed:
                         self.log.info("Marked %d SchedulerJob instances as 
failed", num_failed)
@@ -2361,7 +2386,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         """Mark any "deferred" task as failed if the trigger or execution 
timeout has passed."""
         for attempt in run_with_db_retries(max_retries, logger=self.log):
             with attempt:
-                num_timed_out_tasks = session.execute(
+                result = session.execute(
                     update(TI)
                     .where(
                         TI.state == TaskInstanceState.DEFERRED,
@@ -2374,7 +2399,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                         scheduled_dttm=timezone.utcnow(),
                         trigger_id=None,
                     )
-                ).rowcount
+                )
+                num_timed_out_tasks = getattr(result, "rowcount", 0)
                 if num_timed_out_tasks:
                     self.log.info("Timed out %i deferred tasks without fired 
triggers", num_timed_out_tasks)
 
diff --git a/airflow-core/src/airflow/models/taskreschedule.py b/airflow-core/src/airflow/models/taskreschedule.py
index 996bce62d0d..db36b580787 100644
--- a/airflow-core/src/airflow/models/taskreschedule.py
+++ b/airflow-core/src/airflow/models/taskreschedule.py
@@ -66,12 +66,12 @@ class TaskReschedule(Base):
 
     def __init__(
         self,
-        ti_id: uuid.UUID,
+        ti_id: uuid.UUID | str,
         start_date: datetime.datetime,
         end_date: datetime.datetime,
         reschedule_date: datetime.datetime,
     ) -> None:
-        self.ti_id = ti_id
+        self.ti_id = str(ti_id)
         self.start_date = start_date
         self.end_date = end_date
         self.reschedule_date = reschedule_date
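
Widening ti_id to uuid.UUID | str and normalizing with str() lets call sites type-check whether the id arrives as a UUID object or its string form, which is what the UUID(str(task_instance.id)) call earlier in this commit relies on. A usage sketch with placeholder timestamps:

    import datetime
    import uuid

    from airflow.models.taskreschedule import TaskReschedule

    ti_id = uuid.uuid4()
    now = datetime.datetime.now(tz=datetime.timezone.utc)

    tr_from_uuid = TaskReschedule(ti_id, now, now, now)      # uuid.UUID accepted
    tr_from_str = TaskReschedule(str(ti_id), now, now, now)  # str accepted too
    assert tr_from_uuid.ti_id == tr_from_str.ti_id           # both normalized to str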
