This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1a8d872c6d5 Fix mypy type errors in
/api_fastapi/core_api/routes/public/ for SQLAlchemy 2 migration (#57317)
1a8d872c6d5 is described below
commit 1a8d872c6d576736539e9b6f0b7fb786388ed1df
Author: Anusha Kovi <[email protected]>
AuthorDate: Mon Oct 27 18:57:26 2025 +0530
Fix mypy type errors in /api_fastapi/core_api/routes/public/ for SQLAlchemy
2 migration (#57317)
---
.../core_api/routes/public/connections.py | 4 +-
.../api_fastapi/core_api/routes/public/dag_run.py | 40 +++++++-------
.../core_api/routes/public/dag_versions.py | 2 +-
.../core_api/routes/public/dag_warning.py | 2 +-
.../api_fastapi/core_api/routes/public/dags.py | 8 +--
.../core_api/routes/public/event_logs.py | 2 +-
.../api_fastapi/core_api/routes/public/hitl.py | 2 +-
.../core_api/routes/public/import_error.py | 4 +-
.../api_fastapi/core_api/routes/public/job.py | 2 +-
.../core_api/routes/public/task_instances.py | 63 +++++++++++-----------
dev/breeze/src/airflow_breeze/global_constants.py | 2 +
.../src/airflow_breeze/utils/selective_checks.py | 2 +-
12 files changed, 71 insertions(+), 62 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
index 28db5ae4fef..0b669cd236c 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
@@ -138,7 +138,7 @@ def get_connections(
connections = session.scalars(connection_select)
return ConnectionCollectionResponse(
- connections=connections,
+ connections=list(connections),
total_entries=total_entries,
)
@@ -195,7 +195,7 @@ def patch_connection(
"The connection_id in the request body does not match the URL
parameter",
)
- connection: Connection =
session.scalar(select(Connection).filter_by(conn_id=connection_id).limit(1))
+ connection =
session.scalar(select(Connection).filter_by(conn_id=connection_id).limit(1))
if connection is None:
raise HTTPException(
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py
index fa1b0780519..a1edd7eea10 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py
@@ -205,16 +205,18 @@ def patch_dag_run(
except Exception:
log.exception("error calling listener")
elif attr_name == "note":
- dag_run = session.get(DagRun, dag_run.id)
- if dag_run.dag_run_note is None:
- dag_run.note = (attr_value, user.get_id())
- else:
- dag_run.dag_run_note.content = attr_value
- dag_run.dag_run_note.user_id = user.get_id()
+ updated_dag_run = session.get(DagRun, dag_run.id)
+ if updated_dag_run and updated_dag_run.dag_run_note is None:
+ updated_dag_run.note = (attr_value, user.get_id())
+ elif updated_dag_run:
+ updated_dag_run.dag_run_note.content = attr_value
+ updated_dag_run.dag_run_note.user_id = user.get_id()
- dag_run = session.get(DagRun, dag_run.id)
+ final_dag_run = session.get(DagRun, dag_run.id)
+ if not final_dag_run:
+ raise HTTPException(status.HTTP_404_NOT_FOUND, "DAG run not found
after update")
- return dag_run
+ return final_dag_run
@dag_run_router.get(
@@ -303,6 +305,8 @@ def clear_dag_run(
session=session,
)
dag_run_cleared = session.scalar(select(DagRun).where(DagRun.id ==
dag_run.id))
+ if not dag_run_cleared:
+ raise HTTPException(status.HTTP_404_NOT_FOUND, "DAG run not found
after clearing")
return dag_run_cleared
@@ -401,7 +405,7 @@ def get_dag_runs(
dag_runs = session.scalars(dag_run_select)
return DAGRunCollectionResponse(
- dag_runs=dag_runs,
+ dag_runs=list(dag_runs),
total_entries=total_entries,
)
@@ -537,7 +541,7 @@ def get_list_dag_runs_batch(
session: SessionDep,
) -> DAGRunCollectionResponse:
"""Get a list of DAG Runs."""
- dag_ids = FilterParam(DagRun.dag_id, body.dag_ids, FilterOptionEnum.IN)
+ dag_ids = FilterParam(DagRun.dag_id, body.dag_ids, FilterOptionEnum.IN) #
type: ignore[arg-type]
logical_date = RangeFilter(
Range(
lower_bound_gte=body.logical_date_gte,
@@ -545,7 +549,7 @@ def get_list_dag_runs_batch(
upper_bound_lte=body.logical_date_lte,
upper_bound_lt=body.logical_date_lt,
),
- attribute=DagRun.logical_date,
+ attribute=DagRun.logical_date, # type: ignore[arg-type]
)
run_after = RangeFilter(
Range(
@@ -554,7 +558,7 @@ def get_list_dag_runs_batch(
upper_bound_lte=body.run_after_lte,
upper_bound_lt=body.run_after_lt,
),
- attribute=DagRun.run_after,
+ attribute=DagRun.run_after, # type: ignore[arg-type]
)
start_date = RangeFilter(
Range(
@@ -563,7 +567,7 @@ def get_list_dag_runs_batch(
upper_bound_lte=body.start_date_lte,
upper_bound_lt=body.start_date_lt,
),
- attribute=DagRun.start_date,
+ attribute=DagRun.start_date, # type: ignore[arg-type]
)
end_date = RangeFilter(
Range(
@@ -572,7 +576,7 @@ def get_list_dag_runs_batch(
upper_bound_lte=body.end_date_lte,
upper_bound_lt=body.end_date_lt,
),
- attribute=DagRun.end_date,
+ attribute=DagRun.end_date, # type: ignore[arg-type]
)
duration = RangeFilter(
Range(
@@ -581,10 +585,10 @@ def get_list_dag_runs_batch(
upper_bound_lte=body.duration_lte,
upper_bound_lt=body.duration_lt,
),
- attribute=DagRun.duration,
+ attribute=DagRun.duration, # type: ignore[arg-type]
)
- conf_contains = FilterParam(DagRun.conf, body.conf_contains,
FilterOptionEnum.CONTAINS)
- state = FilterParam(DagRun.state, body.states, FilterOptionEnum.ANY_EQUAL)
+ conf_contains = FilterParam(DagRun.conf, body.conf_contains,
FilterOptionEnum.CONTAINS) # type: ignore[arg-type]
+ state = FilterParam(DagRun.state, body.states, FilterOptionEnum.ANY_EQUAL)
# type: ignore[arg-type]
offset = OffsetFilter(body.page_offset)
limit = LimitFilter(body.page_limit)
@@ -629,6 +633,6 @@ def get_list_dag_runs_batch(
dag_runs = session.scalars(dag_runs_select)
return DAGRunCollectionResponse(
- dag_runs=dag_runs,
+ dag_runs=list(dag_runs),
total_entries=total_entries,
)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_versions.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_versions.py
index e4f282f6088..ab8476bb546 100644
---
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_versions.py
+++
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_versions.py
@@ -125,6 +125,6 @@ def get_dag_versions(
dag_versions = session.scalars(dag_versions_select)
return DAGVersionCollectionResponse(
- dag_versions=dag_versions,
+ dag_versions=list(dag_versions),
total_entries=total_entries,
)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_warning.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_warning.py
index 3884a165e99..3a60500382b 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_warning.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_warning.py
@@ -76,6 +76,6 @@ def list_dag_warnings(
dag_warnings = session.scalars(dag_warnings_select)
return DAGWarningCollectionResponse(
- dag_warnings=dag_warnings,
+ dag_warnings=list(dag_warnings),
total_entries=total_entries,
)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dags.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dags.py
index a50c1b35069..5b3b154ce3c 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dags.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dags.py
@@ -165,7 +165,7 @@ def get_dags(
dags = session.scalars(dags_select)
return DAGCollectionResponse(
- dags=dags,
+ dags=list(dags),
total_entries=total_entries,
)
@@ -188,7 +188,7 @@ def get_dag(
) -> DAGResponse:
"""Get basic information about a DAG."""
dag = get_latest_version_of_dag(dag_bag, dag_id, session)
- dag_model: DagModel = session.get(DagModel, dag_id)
+ dag_model = session.get(DagModel, dag_id)
if not dag_model:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Unable to obtain dag
with id {dag_id} from session")
@@ -215,7 +215,7 @@ def get_dag_details(
"""Get details of DAG."""
dag = get_latest_version_of_dag(dag_bag, dag_id, session)
- dag_model: DagModel = session.get(DagModel, dag_id)
+ dag_model = session.get(DagModel, dag_id)
if not dag_model:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Unable to obtain dag
with id {dag_id} from session")
@@ -341,7 +341,7 @@ def patch_dags(
)
return DAGCollectionResponse(
- dags=dags,
+ dags=list(dags),
total_entries=total_entries,
)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/event_logs.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/event_logs.py
index 1bce6a65a50..fc07f91e992 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/event_logs.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/event_logs.py
@@ -159,6 +159,6 @@ def get_event_logs(
event_logs = session.scalars(event_logs_select)
return EventLogCollectionResponse(
- event_logs=event_logs,
+ event_logs=list(event_logs),
total_entries=total_entries,
)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/hitl.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/hitl.py
index af69425bb01..ecc266c1e58 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/hitl.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/hitl.py
@@ -289,6 +289,6 @@ def get_hitl_details(
hitl_details = session.scalars(hitl_detail_select)
return HITLDetailCollection(
- hitl_details=hitl_details,
+ hitl_details=list(hitl_details),
total_entries=total_entries,
)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py
index e9d6eaddb02..bfc31f5c5b9 100644
---
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py
+++
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/import_error.py
@@ -147,7 +147,7 @@ def get_import_errors(
# Early return if the user has access to all DAGs
import_errors = session.scalars(import_errors_select).all()
return ImportErrorCollectionResponse(
- import_errors=import_errors,
+ import_errors=list(import_errors),
total_entries=total_entries,
)
@@ -205,6 +205,6 @@ def get_import_errors(
import_errors.append(import_error)
return ImportErrorCollectionResponse(
- import_errors=import_errors,
+ import_errors=list(import_errors),
total_entries=total_entries,
)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/job.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/job.py
index fa6c3b0a71e..2fad029bd57 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/job.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/job.py
@@ -130,6 +130,6 @@ def get_jobs(
jobs = [job for job in jobs if job.is_alive()]
return JobCollectionResponse(
- jobs=jobs,
+ jobs=list(jobs),
total_entries=total_entries,
)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 28485e945f3..9ea513b5b9b 100644
---
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -236,7 +236,7 @@ def get_mapped_task_instances(
task_instances = session.scalars(task_instance_select)
return TaskInstanceCollectionResponse(
- task_instances=task_instances,
+ task_instances=list(task_instances),
total_entries=total_entries,
)
@@ -278,7 +278,10 @@ def get_task_instance_dependencies(
if ti.state in [None, TaskInstanceState.SCHEDULED]:
dag_run = session.scalar(select(DagRun).where(DagRun.dag_id ==
ti.dag_id, DagRun.run_id == ti.run_id))
- dag = dag_bag.get_dag_for_run(dag_run, session=session)
+ if dag_run:
+ dag = dag_bag.get_dag_for_run(dag_run, session=session)
+ else:
+ dag = None
if dag:
try:
@@ -331,7 +334,7 @@ def get_task_instance_tries(
tis = session.scalars(
_query(TI).where(or_(TI.state != TaskInstanceState.UP_FOR_RETRY,
TI.state.is_(None)))
).all()
- task_instances = session.scalars(_query(TIH)).all() + tis
+ task_instances = list(session.scalars(_query(TIH)).all()) + list(tis)
if not task_instances:
raise HTTPException(
@@ -512,7 +515,7 @@ def get_task_instances(
task_instances = session.scalars(task_instance_select)
return TaskInstanceCollectionResponse(
- task_instances=task_instances,
+ task_instances=list(task_instances),
total_entries=total_entries,
)
@@ -533,9 +536,9 @@ def get_task_instances_batch(
session: SessionDep,
) -> TaskInstanceCollectionResponse:
"""Get list of task instances."""
- dag_ids = FilterParam(TI.dag_id, body.dag_ids, FilterOptionEnum.IN)
- dag_run_ids = FilterParam(TI.run_id, body.dag_run_ids, FilterOptionEnum.IN)
- task_ids = FilterParam(TI.task_id, body.task_ids, FilterOptionEnum.IN)
+ dag_ids = FilterParam(TI.dag_id, body.dag_ids, FilterOptionEnum.IN) #
type: ignore[arg-type]
+ dag_run_ids = FilterParam(TI.run_id, body.dag_run_ids,
FilterOptionEnum.IN) # type: ignore[arg-type]
+ task_ids = FilterParam(TI.task_id, body.task_ids, FilterOptionEnum.IN) #
type: ignore[arg-type]
run_after = RangeFilter(
Range(
lower_bound_gte=body.run_after_gte,
@@ -543,7 +546,7 @@ def get_task_instances_batch(
upper_bound_lte=body.run_after_lte,
upper_bound_lt=body.run_after_lt,
),
- attribute=TI.run_after,
+ attribute=TI.run_after, # type: ignore[arg-type]
)
logical_date = RangeFilter(
Range(
@@ -552,7 +555,7 @@ def get_task_instances_batch(
upper_bound_lte=body.logical_date_lte,
upper_bound_lt=body.logical_date_lt,
),
- attribute=TI.logical_date,
+ attribute=TI.logical_date, # type: ignore[arg-type]
)
start_date = RangeFilter(
Range(
@@ -561,7 +564,7 @@ def get_task_instances_batch(
upper_bound_lte=body.start_date_lte,
upper_bound_lt=body.start_date_lt,
),
- attribute=TI.start_date,
+ attribute=TI.start_date, # type: ignore[arg-type]
)
end_date = RangeFilter(
Range(
@@ -570,7 +573,7 @@ def get_task_instances_batch(
upper_bound_lte=body.end_date_lte,
upper_bound_lt=body.end_date_lt,
),
- attribute=TI.end_date,
+ attribute=TI.end_date, # type: ignore[arg-type]
)
duration = RangeFilter(
Range(
@@ -579,12 +582,12 @@ def get_task_instances_batch(
upper_bound_lte=body.duration_lte,
upper_bound_lt=body.duration_lt,
),
- attribute=TI.duration,
+ attribute=TI.duration, # type: ignore[arg-type]
)
- state = FilterParam(TI.state, body.state, FilterOptionEnum.ANY_EQUAL)
- pool = FilterParam(TI.pool, body.pool, FilterOptionEnum.ANY_EQUAL)
- queue = FilterParam(TI.queue, body.queue, FilterOptionEnum.ANY_EQUAL)
- executor = FilterParam(TI.executor, body.executor,
FilterOptionEnum.ANY_EQUAL)
+ state = FilterParam(TI.state, body.state, FilterOptionEnum.ANY_EQUAL) #
type: ignore[arg-type]
+ pool = FilterParam(TI.pool, body.pool, FilterOptionEnum.ANY_EQUAL) #
type: ignore[arg-type]
+ queue = FilterParam(TI.queue, body.queue, FilterOptionEnum.ANY_EQUAL) #
type: ignore[arg-type]
+ executor = FilterParam(TI.executor, body.executor,
FilterOptionEnum.ANY_EQUAL) # type: ignore[arg-type]
offset = OffsetFilter(body.page_offset)
limit = LimitFilter(body.page_limit)
@@ -626,7 +629,7 @@ def get_task_instances_batch(
task_instances = session.scalars(task_instance_select)
return TaskInstanceCollectionResponse(
- task_instances=task_instances,
+ task_instances=list(task_instances),
total_entries=total_entries,
)
@@ -765,28 +768,28 @@ def post_clear_task_instances(
]
task_ids = mapped_tasks_list + list(unmapped_task_ids)
- # Prepare common parameters
- common_params = {
- "dry_run": True,
- "task_ids": task_ids,
- "session": session,
- "run_on_latest_version": body.run_on_latest_version,
- "only_failed": body.only_failed,
- "only_running": body.only_running,
- }
-
if dag_run_id is not None and not (past or future):
# Use run_id-based clearing when we have a specific dag_run_id and not
using past/future
task_instances = dag.clear(
- **common_params,
+ dry_run=True,
+ task_ids=task_ids,
run_id=dag_run_id,
+ session=session,
+ run_on_latest_version=body.run_on_latest_version,
+ only_failed=body.only_failed,
+ only_running=body.only_running,
)
else:
# Use date-based clearing when no dag_run_id or when past/future is
specified
task_instances = dag.clear(
- **common_params,
+ dry_run=True,
+ task_ids=task_ids,
start_date=body.start_date,
end_date=body.end_date,
+ session=session,
+ run_on_latest_version=body.run_on_latest_version,
+ only_failed=body.only_failed,
+ only_running=body.only_running,
)
if not dry_run:
@@ -798,7 +801,7 @@ def post_clear_task_instances(
)
return TaskInstanceCollectionResponse(
- task_instances=task_instances,
+ task_instances=cast("list[TaskInstanceResponse]", task_instances),
total_entries=len(task_instances),
)
diff --git a/dev/breeze/src/airflow_breeze/global_constants.py
b/dev/breeze/src/airflow_breeze/global_constants.py
index 91df0f2704f..81b37670a56 100644
--- a/dev/breeze/src/airflow_breeze/global_constants.py
+++ b/dev/breeze/src/airflow_breeze/global_constants.py
@@ -44,6 +44,8 @@ PUBLIC_ARM_RUNNERS = '["ubuntu-22.04-arm"]'
RUNNERS_TYPE_CROSS_MAPPING = {
"ubuntu-22.04": '["ubuntu-22.04-arm"]',
"ubuntu-22.04-arm": '["ubuntu-22.04"]',
+ "windows-2022": '["windows-2022"]',
+ "windows-2025": '["windows-2025"]',
}
ANSWER = ""
diff --git a/dev/breeze/src/airflow_breeze/utils/selective_checks.py
b/dev/breeze/src/airflow_breeze/utils/selective_checks.py
index f1185222d44..57d024d3531 100644
--- a/dev/breeze/src/airflow_breeze/utils/selective_checks.py
+++ b/dev/breeze/src/airflow_breeze/utils/selective_checks.py
@@ -1345,7 +1345,7 @@ class SelectiveChecks:
branch = self._github_context_dict.get("ref_name", "main")
label =
self.get_job_label(event_type=str(self._github_event.value), branch=branch)
- return RUNNERS_TYPE_CROSS_MAPPING[label] if label else
PUBLIC_AMD_RUNNERS
+ return RUNNERS_TYPE_CROSS_MAPPING.get(label, PUBLIC_AMD_RUNNERS)
if label else PUBLIC_AMD_RUNNERS
return PUBLIC_AMD_RUNNERS