This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 83f5950976 Refactor Sqlalchemy queries to 2.0 styles (Part 2) (#31772)
83f5950976 is described below
commit 83f595097606bbe48d716b1c94eace7bd6ac4e77
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Tue Jun 27 15:09:05 2023 +0100
Refactor Sqlalchemy queries to 2.0 styles (Part 2) (#31772)
* Refactor Sqlalchemy queries to 2.0 styles (Part 2)
This is a continuation of the effort to refactor the queries to sqlalchemy
2.0 style
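As an illustrative sketch of the pattern (mirroring the connection
endpoint hunk in the diff below; not itself a new line of this commit):

    # 1.x Query API (before)
    connection = session.query(Connection).filter_by(conn_id=connection_id).one_or_none()
    # 2.0 select() API (after)
    connection = session.scalar(select(Connection).filter_by(conn_id=connection_id))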
* Fix total_entries counting logic
The count should happen *after* we apply the filters, and is (slightly)
different from the automatic 404 case.
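Sketch of the corrected ordering (mirroring the dag endpoint hunk below;
assumes the select()/func imports shown in the diff): build the filtered
statement first, then count it.

    dags_query = select(DagModel).where(~DagModel.is_subdag)  # filters applied first
    total_entries = session.scalar(select(func.count()).select_from(dags_query))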
* Update queries in models/
* fixup! Update queries in models/
* Add triggerer_job in query
* fixup! fixup! Update queries in models/
* Revert changes to dagwarning.py
* Use session.scalar
* Convert DagWarning to SQLAlchemy 2.x syntax
* Fix dagwarning
* fix typing and remove todo comment
* Apply suggestions from code review
Co-authored-by: Tzu-ping Chung <[email protected]>
* Use scalars instead of execute and fix typing
* Apply suggestions from code review
Co-authored-by: Tzu-ping Chung <[email protected]>
* Apply suggestions from code review
Co-authored-by: Tzu-ping Chung <[email protected]>
* remove all
* Remove one_or_none where possible
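session.scalar() already returns a single result or None, so a trailing
.one_or_none() becomes redundant; a sketch mirroring the pool endpoint
change below:

    # before
    obj = session.query(Pool).filter(Pool.pool == pool_name).one_or_none()
    # after: scalar() yields one object or None
    obj = session.scalar(select(Pool).where(Pool.pool == pool_name))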
---------
Co-authored-by: Tzu-ping Chung <[email protected]>
---
.../api_connexion/endpoints/connection_endpoint.py | 17 +-
airflow/api_connexion/endpoints/dag_endpoint.py | 44 +--
.../api_connexion/endpoints/dag_run_endpoint.py | 78 ++---
.../endpoints/dag_warning_endpoint.py | 11 +-
.../api_connexion/endpoints/dataset_endpoint.py | 36 ++-
.../api_connexion/endpoints/event_log_endpoint.py | 8 +-
.../api_connexion/endpoints/extra_link_endpoint.py | 7 +-
.../endpoints/import_error_endpoint.py | 8 +-
airflow/api_connexion/endpoints/log_endpoint.py | 12 +-
airflow/api_connexion/endpoints/pool_endpoint.py | 12 +-
.../endpoints/role_and_permission_endpoint.py | 18 +-
.../endpoints/task_instance_endpoint.py | 141 +++++----
airflow/api_connexion/endpoints/user_endpoint.py | 8 +-
.../api_connexion/endpoints/variable_endpoint.py | 16 +-
airflow/api_connexion/endpoints/xcom_endpoint.py | 30 +-
airflow/api_connexion/parameters.py | 6 +-
airflow/models/abstractoperator.py | 35 ++-
airflow/models/baseoperator.py | 34 +--
airflow/models/dag.py | 333 +++++++++++----------
airflow/models/dagcode.py | 16 +-
airflow/models/dagrun.py | 183 +++++------
airflow/models/dagwarning.py | 11 +-
airflow/models/pool.py | 64 ++--
.../test_mapped_task_instance_endpoint.py | 6 +-
tests/models/test_dagwarning.py | 11 +-
25 files changed, 597 insertions(+), 548 deletions(-)
diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py
index 641c3288d3..737ee54d6c 100644
--- a/airflow/api_connexion/endpoints/connection_endpoint.py
+++ b/airflow/api_connexion/endpoints/connection_endpoint.py
@@ -22,7 +22,7 @@ from http import HTTPStatus
from connexion import NoContent
from flask import request
from marshmallow import ValidationError
-from sqlalchemy import func
+from sqlalchemy import func, select
from sqlalchemy.orm import Session
from airflow.api_connexion import security
@@ -58,7 +58,7 @@ RESOURCE_EVENT_PREFIX = "connection"
)
def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Delete a connection entry."""
- connection = session.query(Connection).filter_by(conn_id=connection_id).one_or_none()
+ connection = session.scalar(select(Connection).filter_by(conn_id=connection_id))
if connection is None:
raise NotFound(
"Connection not found",
@@ -72,7 +72,7 @@ def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) ->
@provide_session
def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a connection entry."""
- connection = session.query(Connection).filter(Connection.conn_id == connection_id).one_or_none()
+ connection = session.scalar(select(Connection).where(Connection.conn_id == connection_id))
if connection is None:
raise NotFound(
"Connection not found",
@@ -95,10 +95,10 @@ def get_connections(
to_replace = {"connection_id": "conn_id"}
allowed_filter_attrs = ["connection_id", "conn_type", "description",
"host", "port", "id"]
- total_entries = session.query(func.count(Connection.id)).scalar()
- query = session.query(Connection)
+ total_entries = session.execute(select(func.count(Connection.id))).scalar_one()
+ query = select(Connection)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
- connections = query.offset(offset).limit(limit).all()
+ connections = session.scalars(query.offset(offset).limit(limit)).all()
return connection_collection_schema.dump(
ConnectionCollection(connections=connections, total_entries=total_entries)
)
@@ -125,7 +125,7 @@ def patch_connection(
# If validation get to here, it is extra field validation.
raise BadRequest(detail=str(err.messages))
non_update_fields = ["connection_id", "conn_id"]
- connection = session.query(Connection).filter_by(conn_id=connection_id).first()
+ connection = session.scalar(select(Connection).filter_by(conn_id=connection_id).limit(1))
if connection is None:
raise NotFound(
"Connection not found",
@@ -162,8 +162,7 @@ def post_connection(*, session: Session = NEW_SESSION) -> APIResponse:
helpers.validate_key(conn_id, max_length=200)
except Exception as e:
raise BadRequest(detail=str(e))
- query = session.query(Connection)
- connection = query.filter_by(conn_id=conn_id).first()
+ connection = session.scalar(select(Connection).filter_by(conn_id=conn_id).limit(1))
if not connection:
connection = Connection(**data)
session.add(connection)
diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py
index 18dcb45be6..55eb971c08 100644
--- a/airflow/api_connexion/endpoints/dag_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_endpoint.py
@@ -22,6 +22,7 @@ from typing import Collection
from connexion import NoContent
from flask import g, request
from marshmallow import ValidationError
+from sqlalchemy import func, select, update
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import or_
@@ -47,7 +48,7 @@ from airflow.utils.session import NEW_SESSION, provide_session
@provide_session
def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get basic information about a DAG."""
- dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none()
+ dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id))
if dag is None:
raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id}
was not found")
@@ -80,27 +81,27 @@ def get_dags(
) -> APIResponse:
"""Get all DAGs."""
allowed_attrs = ["dag_id"]
- dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
+ dags_query = select(DagModel).where(~DagModel.is_subdag)
if only_active:
- dags_query = dags_query.filter(DagModel.is_active)
+ dags_query = dags_query.where(DagModel.is_active)
if paused is not None:
if paused:
- dags_query = dags_query.filter(DagModel.is_paused)
+ dags_query = dags_query.where(DagModel.is_paused)
else:
- dags_query = dags_query.filter(~DagModel.is_paused)
+ dags_query = dags_query.where(~DagModel.is_paused)
if dag_id_pattern:
- dags_query = dags_query.filter(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))
+ dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))
readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
- dags_query = dags_query.filter(DagModel.dag_id.in_(readable_dags))
+ dags_query = dags_query.where(DagModel.dag_id.in_(readable_dags))
if tags:
cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
- dags_query = dags_query.filter(or_(*cond))
+ dags_query = dags_query.where(or_(*cond))
- total_entries = dags_query.count()
+ total_entries = session.scalar(select(func.count()).select_from(dags_query))
dags_query = apply_sorting(dags_query, order_by, {}, allowed_attrs)
- dags = dags_query.offset(offset).limit(limit).all()
+ dags = session.scalars(dags_query.offset(offset).limit(limit)).all()
return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries))
@@ -119,7 +120,7 @@ def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session =
raise BadRequest(detail="Only `is_paused` field can be updated
through the REST API")
patch_body_[update_mask[0]] = patch_body[update_mask[0]]
patch_body = patch_body_
- dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none()
+ dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id))
if not dag:
raise NotFound(f"Dag with id: '{dag_id}' not found")
dag.is_paused = patch_body["is_paused"]
@@ -144,27 +145,30 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
patch_body_[update_mask] = patch_body[update_mask]
patch_body = patch_body_
if only_active:
- dags_query = session.query(DagModel).filter(~DagModel.is_subdag, DagModel.is_active)
+ dags_query = select(DagModel).where(~DagModel.is_subdag, DagModel.is_active)
else:
- dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
+ dags_query = select(DagModel).where(~DagModel.is_subdag)
if dag_id_pattern == "~":
dag_id_pattern = "%"
- dags_query = dags_query.filter(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))
+ dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))
editable_dags = get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user)
- dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))
+ dags_query = dags_query.where(DagModel.dag_id.in_(editable_dags))
if tags:
cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
- dags_query = dags_query.filter(or_(*cond))
+ dags_query = dags_query.where(or_(*cond))
- total_entries = dags_query.count()
+ total_entries = session.scalar(select(func.count()).select_from(dags_query))
- dags = dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all()
+ dags = session.scalars(dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit)).all()
dags_to_update = {dag.dag_id for dag in dags}
- session.query(DagModel).filter(DagModel.dag_id.in_(dags_to_update)).update(
- {DagModel.is_paused: patch_body["is_paused"]},
synchronize_session="fetch"
+ session.execute(
+ update(DagModel)
+ .where(DagModel.dag_id.in_(dags_to_update))
+ .values(is_paused=patch_body["is_paused"])
+ .execution_options(synchronize_session="fetch")
)
session.flush()
diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py
index 95e5913ebd..f62b28273f 100644
--- a/airflow/api_connexion/endpoints/dag_run_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -23,8 +23,9 @@ from connexion import NoContent
from flask import g
from flask_login import current_user
from marshmallow import ValidationError
-from sqlalchemy import delete, or_
-from sqlalchemy.orm import Query, Session
+from sqlalchemy import delete, func, or_, select
+from sqlalchemy.orm import Session
+from sqlalchemy.sql import Select
from airflow.api.common.mark_tasks import (
set_dag_run_state_to_failed,
@@ -91,7 +92,7 @@ def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSI
@provide_session
def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a DAG Run."""
- dag_run = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
+ dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id))
if dag_run is None:
raise NotFound(
"DAGRun not found",
@@ -112,13 +113,11 @@ def get_upstream_dataset_events(
*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION
) -> APIResponse:
"""If dag run is dataset-triggered, return the dataset events that
triggered it."""
- dag_run: DagRun | None = (
- session.query(DagRun)
- .filter(
+ dag_run: DagRun | None = session.scalar(
+ select(DagRun).where(
DagRun.dag_id == dag_id,
DagRun.run_id == dag_run_id,
)
- .one_or_none()
)
if dag_run is None:
raise NotFound(
@@ -132,7 +131,7 @@ def get_upstream_dataset_events(
def _fetch_dag_runs(
- query: Query,
+ query: Select,
*,
end_date_gte: str | None,
end_date_lte: str | None,
@@ -145,28 +144,29 @@ def _fetch_dag_runs(
limit: int | None,
offset: int | None,
order_by: str,
+ session: Session,
) -> tuple[list[DagRun], int]:
if start_date_gte:
- query = query.filter(DagRun.start_date >= start_date_gte)
+ query = query.where(DagRun.start_date >= start_date_gte)
if start_date_lte:
- query = query.filter(DagRun.start_date <= start_date_lte)
+ query = query.where(DagRun.start_date <= start_date_lte)
# filter execution date
if execution_date_gte:
- query = query.filter(DagRun.execution_date >= execution_date_gte)
+ query = query.where(DagRun.execution_date >= execution_date_gte)
if execution_date_lte:
- query = query.filter(DagRun.execution_date <= execution_date_lte)
+ query = query.where(DagRun.execution_date <= execution_date_lte)
# filter end date
if end_date_gte:
- query = query.filter(DagRun.end_date >= end_date_gte)
+ query = query.where(DagRun.end_date >= end_date_gte)
if end_date_lte:
- query = query.filter(DagRun.end_date <= end_date_lte)
+ query = query.where(DagRun.end_date <= end_date_lte)
# filter updated at
if updated_at_gte:
- query = query.filter(DagRun.updated_at >= updated_at_gte)
+ query = query.where(DagRun.updated_at >= updated_at_gte)
if updated_at_lte:
- query = query.filter(DagRun.updated_at <= updated_at_lte)
+ query = query.where(DagRun.updated_at <= updated_at_lte)
- total_entries = query.count()
+ total_entries = session.scalar(select(func.count()).select_from(query))
to_replace = {"dag_run_id": "run_id"}
allowed_filter_attrs = [
"id",
@@ -181,7 +181,7 @@ def _fetch_dag_runs(
"conf",
]
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
- return query.offset(offset).limit(limit).all(), total_entries
+ return session.scalars(query.offset(offset).limit(limit)).all(), total_entries
@security.requires_access(
@@ -222,17 +222,17 @@ def get_dag_runs(
session: Session = NEW_SESSION,
):
"""Get all DAG Runs."""
- query = session.query(DagRun)
+ query = select(DagRun)
# This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs.
if dag_id == "~":
appbuilder = get_airflow_app().appbuilder
- query = query.filter(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user)))
+ query = query.where(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user)))
else:
- query = query.filter(DagRun.dag_id == dag_id)
+ query = query.where(DagRun.dag_id == dag_id)
if state:
- query = query.filter(DagRun.state.in_(state))
+ query = query.where(DagRun.state.in_(state))
dag_run, total_entries = _fetch_dag_runs(
query,
@@ -247,6 +247,7 @@ def get_dag_runs(
limit=limit,
offset=offset,
order_by=order_by,
+ session=session,
)
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, total_entries=total_entries))
@@ -268,16 +269,16 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
appbuilder = get_airflow_app().appbuilder
readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
- query = session.query(DagRun)
+ query = select(DagRun)
if data.get("dag_ids"):
dag_ids = set(data["dag_ids"]) & set(readable_dag_ids)
- query = query.filter(DagRun.dag_id.in_(dag_ids))
+ query = query.where(DagRun.dag_id.in_(dag_ids))
else:
- query = query.filter(DagRun.dag_id.in_(readable_dag_ids))
+ query = query.where(DagRun.dag_id.in_(readable_dag_ids))
states = data.get("states")
if states:
- query = query.filter(DagRun.state.in_(states))
+ query = query.where(DagRun.state.in_(states))
dag_runs, total_entries = _fetch_dag_runs(
query,
@@ -290,6 +291,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
limit=data["page_limit"],
offset=data["page_offset"],
order_by=data.get("order_by", "id"),
+ session=session,
)
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))
@@ -310,7 +312,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
)
def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Trigger a DAG."""
- dm = session.query(DagModel).filter(DagModel.is_active, DagModel.dag_id == dag_id).first()
+ dm = session.scalar(select(DagModel).where(DagModel.is_active, DagModel.dag_id == dag_id).limit(1))
if not dm:
raise NotFound(title="DAG not found", detail=f"DAG with dag_id:
'{dag_id}' not found")
if dm.has_import_errors:
@@ -325,13 +327,13 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
logical_date = pendulum.instance(post_body["execution_date"])
run_id = post_body["run_id"]
- dagrun_instance = (
- session.query(DagRun)
- .filter(
+ dagrun_instance = session.scalar(
+ select(DagRun)
+ .where(
DagRun.dag_id == dag_id,
or_(DagRun.run_id == run_id, DagRun.execution_date == logical_date),
)
- .first()
+ .limit(1)
)
if not dagrun_instance:
try:
@@ -375,8 +377,8 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
@provide_session
def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Set a state of a dag run."""
- dag_run: DagRun | None = (
- session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
+ dag_run: DagRun | None = session.scalar(
+ select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)
)
if dag_run is None:
error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}"
@@ -407,8 +409,8 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW
@provide_session
def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Clear a dag run."""
- dag_run: DagRun | None = (
- session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
+ dag_run: DagRun | None = session.scalar(
+ select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)
)
if dag_run is None:
error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}"
@@ -445,7 +447,7 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO
include_parentdag=True,
only_failed=False,
)
- dag_run = session.query(DagRun).filter(DagRun.id == dag_run.id).one()
+ dag_run = session.execute(select(DagRun).where(DagRun.id == dag_run.id)).scalar_one()
return dagrun_schema.dump(dag_run)
@@ -458,8 +460,8 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO
@provide_session
def set_dag_run_note(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Set the note for a dag run."""
- dag_run: DagRun | None = (
- session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
+ dag_run: DagRun | None = session.scalar(
+ select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)
)
if dag_run is None:
error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}"
diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py b/airflow/api_connexion/endpoints/dag_warning_endpoint.py
index 5a73afd1a3..66aa1184a1 100644
--- a/airflow/api_connexion/endpoints/dag_warning_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+from sqlalchemy import func, select
from sqlalchemy.orm import Session
from airflow.api_connexion import security
@@ -48,14 +49,14 @@ def get_dag_warnings(
:param warning_type: the warning type to optionally filter by
"""
allowed_filter_attrs = ["dag_id", "warning_type", "message", "timestamp"]
- query = session.query(DagWarningModel)
+ query = select(DagWarningModel)
if dag_id:
- query = query.filter(DagWarningModel.dag_id == dag_id)
+ query = query.where(DagWarningModel.dag_id == dag_id)
if warning_type:
- query = query.filter(DagWarningModel.warning_type == warning_type)
- total_entries = query.count()
+ query = query.where(DagWarningModel.warning_type == warning_type)
+ total_entries = session.scalar(select(func.count()).select_from(query))
query = apply_sorting(query=query, order_by=order_by, allowed_attrs=allowed_filter_attrs)
- dag_warnings = query.offset(offset).limit(limit).all()
+ dag_warnings = session.scalars(query.offset(offset).limit(limit)).all()
return dag_warning_collection_schema.dump(
DagWarningCollection(dag_warnings=dag_warnings, total_entries=total_entries)
)
diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py
index 42e8bb3c36..9f4fa443b9 100644
--- a/airflow/api_connexion/endpoints/dataset_endpoint.py
+++ b/airflow/api_connexion/endpoints/dataset_endpoint.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from sqlalchemy import func
+from sqlalchemy import func, select
from sqlalchemy.orm import Session, joinedload, subqueryload
from airflow.api_connexion import security
@@ -39,11 +39,10 @@ from airflow.utils.session import NEW_SESSION, provide_session
@provide_session
def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a Dataset."""
- dataset = (
- session.query(DatasetModel)
- .filter(DatasetModel.uri == uri)
+ dataset = session.scalar(
+ select(DatasetModel)
+ .where(DatasetModel.uri == uri)
.options(joinedload(DatasetModel.consuming_dags), joinedload(DatasetModel.producing_tasks))
- .one_or_none()
)
if not dataset:
raise NotFound(
@@ -67,17 +66,16 @@ def get_datasets(
"""Get datasets."""
allowed_attrs = ["id", "uri", "created_at", "updated_at"]
- total_entries = session.query(func.count(DatasetModel.id)).scalar()
- query = session.query(DatasetModel)
+ total_entries = session.scalars(select(func.count(DatasetModel.id))).one()
+ query = select(DatasetModel)
if uri_pattern:
- query = query.filter(DatasetModel.uri.ilike(f"%{uri_pattern}%"))
+ query = query.where(DatasetModel.uri.ilike(f"%{uri_pattern}%"))
query = apply_sorting(query, order_by, {}, allowed_attrs)
- datasets = (
+ datasets = session.scalars(
query.options(subqueryload(DatasetModel.consuming_dags), subqueryload(DatasetModel.producing_tasks))
.offset(offset)
.limit(limit)
- .all()
- )
+ ).all()
return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries))
@@ -99,24 +97,24 @@ def get_dataset_events(
"""Get dataset events."""
allowed_attrs = ["source_dag_id", "source_task_id", "source_run_id",
"source_map_index", "timestamp"]
- query = session.query(DatasetEvent)
+ query = select(DatasetEvent)
if dataset_id:
- query = query.filter(DatasetEvent.dataset_id == dataset_id)
+ query = query.where(DatasetEvent.dataset_id == dataset_id)
if source_dag_id:
- query = query.filter(DatasetEvent.source_dag_id == source_dag_id)
+ query = query.where(DatasetEvent.source_dag_id == source_dag_id)
if source_task_id:
- query = query.filter(DatasetEvent.source_task_id == source_task_id)
+ query = query.where(DatasetEvent.source_task_id == source_task_id)
if source_run_id:
- query = query.filter(DatasetEvent.source_run_id == source_run_id)
+ query = query.where(DatasetEvent.source_run_id == source_run_id)
if source_map_index:
- query = query.filter(DatasetEvent.source_map_index == source_map_index)
+ query = query.where(DatasetEvent.source_map_index == source_map_index)
query = query.options(subqueryload(DatasetEvent.created_dagruns))
- total_entries = query.count()
+ total_entries = session.scalar(select(func.count()).select_from(query))
query = apply_sorting(query, order_by, {}, allowed_attrs)
- events = query.offset(offset).limit(limit).all()
+ events = session.scalars(query.offset(offset).limit(limit)).all()
return dataset_event_collection_schema.dump(
DatasetEventCollection(dataset_events=events, total_entries=total_entries)
)
diff --git a/airflow/api_connexion/endpoints/event_log_endpoint.py b/airflow/api_connexion/endpoints/event_log_endpoint.py
index 335886189c..28615f6fec 100644
--- a/airflow/api_connexion/endpoints/event_log_endpoint.py
+++ b/airflow/api_connexion/endpoints/event_log_endpoint.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from sqlalchemy import func
+from sqlalchemy import func, select
from sqlalchemy.orm import Session
from airflow.api_connexion import security
@@ -65,10 +65,10 @@ def get_event_logs(
"owner",
"extra",
]
- total_entries = session.query(func.count(Log.id)).scalar()
- query = session.query(Log)
+ total_entries = session.scalars(func.count(Log.id)).one()
+ query = select(Log)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
- event_logs = query.offset(offset).limit(limit).all()
+ event_logs = session.scalars(query.offset(offset).limit(limit)).all()
return event_log_collection_schema.dump(
EventLogCollection(event_logs=event_logs, total_entries=total_entries)
)
diff --git a/airflow/api_connexion/endpoints/extra_link_endpoint.py b/airflow/api_connexion/endpoints/extra_link_endpoint.py
index 2b12667e7c..e28822522e 100644
--- a/airflow/api_connexion/endpoints/extra_link_endpoint.py
+++ b/airflow/api_connexion/endpoints/extra_link_endpoint.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+from sqlalchemy import select
from sqlalchemy.orm.session import Session
from airflow import DAG
@@ -57,14 +58,12 @@ def get_extra_links(
except TaskNotFound:
raise NotFound("Task not found", detail=f'Task with ID = "{task_id}"
not found')
- ti = (
- session.query(TaskInstance)
- .filter(
+ ti = session.scalar(
+ select(TaskInstance).where(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == dag_run_id,
TaskInstance.task_id == task_id,
)
- .one_or_none()
)
if not ti:
diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py b/airflow/api_connexion/endpoints/import_error_endpoint.py
index 3ffd8f11a5..5e871e08bd 100644
--- a/airflow/api_connexion/endpoints/import_error_endpoint.py
+++ b/airflow/api_connexion/endpoints/import_error_endpoint.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from sqlalchemy import func
+from sqlalchemy import func, select
from sqlalchemy.orm import Session
from airflow.api_connexion import security
@@ -60,10 +60,10 @@ def get_import_errors(
"""Get all import errors."""
to_replace = {"import_error_id": "id"}
allowed_filter_attrs = ["import_error_id", "timestamp", "filename"]
- total_entries = session.query(func.count(ImportErrorModel.id)).scalar()
- query = session.query(ImportErrorModel)
+ total_entries = session.scalars(func.count(ImportErrorModel.id)).one()
+ query = select(ImportErrorModel)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
- import_errors = query.offset(offset).limit(limit).all()
+ import_errors = session.scalars(query.offset(offset).limit(limit)).all()
return import_error_collection_schema.dump(
ImportErrorCollection(import_errors=import_errors, total_entries=total_entries)
)
diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py
index a81258ce9d..8df712e67c 100644
--- a/airflow/api_connexion/endpoints/log_endpoint.py
+++ b/airflow/api_connexion/endpoints/log_endpoint.py
@@ -21,6 +21,7 @@ from typing import Any
from flask import Response, request
from itsdangerous.exc import BadSignature
from itsdangerous.url_safe import URLSafeSerializer
+from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
from airflow.api_connexion import security
@@ -28,7 +29,7 @@ from airflow.api_connexion.exceptions import BadRequest, NotFound
from airflow.api_connexion.schemas.log_schema import LogResponseObject, logs_schema
from airflow.api_connexion.types import APIResponse
from airflow.exceptions import TaskNotFound
-from airflow.models import TaskInstance
+from airflow.models import TaskInstance, Trigger
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.log.log_reader import TaskLogReader
@@ -77,18 +78,17 @@ def get_log(
if not task_log_reader.supports_read:
raise BadRequest("Task log handler does not support read logs.")
query = (
- session.query(TaskInstance)
- .filter(
+ select(TaskInstance)
+ .where(
TaskInstance.task_id == task_id,
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == dag_run_id,
TaskInstance.map_index == map_index,
)
.join(TaskInstance.dag_run)
- .options(joinedload("trigger"))
- .options(joinedload("trigger.triggerer_job"))
+ .options(joinedload(TaskInstance.trigger).joinedload(Trigger.triggerer_job))
)
- ti = query.one_or_none()
+ ti = session.scalar(query)
if ti is None:
metadata["end_of_log"] = True
raise NotFound(title="TaskInstance not found")
diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py b/airflow/api_connexion/endpoints/pool_endpoint.py
index a760ee4d83..3668741c1e 100644
--- a/airflow/api_connexion/endpoints/pool_endpoint.py
+++ b/airflow/api_connexion/endpoints/pool_endpoint.py
@@ -20,7 +20,7 @@ from http import HTTPStatus
from flask import Response
from marshmallow import ValidationError
-from sqlalchemy import delete, func
+from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -52,7 +52,7 @@ def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIRespons
@provide_session
def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a pool."""
- obj = session.query(Pool).filter(Pool.pool == pool_name).one_or_none()
+ obj = session.scalar(select(Pool).where(Pool.pool == pool_name))
if obj is None:
raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
return pool_schema.dump(obj)
@@ -71,10 +71,10 @@ def get_pools(
"""Get all pools."""
to_replace = {"name": "pool"}
allowed_filter_attrs = ["name", "slots", "id"]
- total_entries = session.query(func.count(Pool.id)).scalar()
- query = session.query(Pool)
+ total_entries = session.scalars(func.count(Pool.id)).one()
+ query = select(Pool)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
- pools = query.offset(offset).limit(limit).all()
+ pools = session.scalars(query.offset(offset).limit(limit)).all()
return pool_collection_schema.dump(PoolCollection(pools=pools, total_entries=total_entries))
@@ -98,7 +98,7 @@ def patch_pool(
except KeyError:
pass
- pool = session.query(Pool).filter(Pool.pool == pool_name).first()
+ pool = session.scalar(select(Pool).where(Pool.pool == pool_name).limit(1))
if not pool:
raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
diff --git a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py
index 4ed40caae5..609c45893f 100644
--- a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py
+++ b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py
@@ -21,7 +21,7 @@ from http import HTTPStatus
from connexion import NoContent
from flask import request
from marshmallow import ValidationError
-from sqlalchemy import asc, desc, func
+from sqlalchemy import asc, desc, func, select
from airflow.api_connexion import security
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
@@ -69,7 +69,7 @@ def get_roles(*, order_by: str = "name", limit: int, offset:
int | None = None)
"""Get roles."""
appbuilder = get_airflow_app().appbuilder
session = appbuilder.get_session
- total_entries = session.query(func.count(Role.id)).scalar()
+ total_entries = session.scalars(select(func.count(Role.id))).one()
direction = desc if order_by.startswith("-") else asc
to_replace = {"role_id": "id"}
order_param = order_by.strip("-")
@@ -81,8 +81,12 @@ def get_roles(*, order_by: str = "name", limit: int, offset: int | None = None)
f"the attribute does not exist on the model"
)
- query = session.query(Role)
- roles = query.order_by(direction(getattr(Role, order_param))).offset(offset).limit(limit).all()
+ query = select(Role)
+ roles = (
+ session.scalars(query.order_by(direction(getattr(Role, order_param))).offset(offset).limit(limit))
+ .unique()
+ .all()
+ )
return role_collection_schema.dump(RoleCollection(roles=roles, total_entries=total_entries))
@@ -92,9 +96,9 @@ def get_roles(*, order_by: str = "name", limit: int, offset: int | None = None)
def get_permissions(*, limit: int, offset: int | None = None) -> APIResponse:
"""Get permissions."""
session = get_airflow_app().appbuilder.get_session
- total_entries = session.query(func.count(Action.id)).scalar()
- query = session.query(Action)
- actions = query.offset(offset).limit(limit).all()
+ total_entries = session.scalars(select(func.count(Action.id))).one()
+ query = select(Action)
+ actions = session.scalars(query.offset(offset).limit(limit)).all()
return action_collection_schema.dump(ActionCollection(actions=actions, total_entries=total_entries))
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 533d97c858..3028b0bb73 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -19,11 +19,10 @@ from __future__ import annotations
from typing import Any, Iterable, TypeVar
from marshmallow import ValidationError
-from sqlalchemy import and_, func, or_
+from sqlalchemy import and_, func, or_, select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import Session, joinedload
-from sqlalchemy.orm.query import Query
-from sqlalchemy.sql import ClauseElement
+from sqlalchemy.sql import ClauseElement, Select
from airflow.api_connexion import security
from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
@@ -72,8 +71,8 @@ def get_task_instance(
) -> APIResponse:
"""Get task instance."""
query = (
- session.query(TI)
- .filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
+ select(TI)
+ .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
.join(TI.dag_run)
.outerjoin(
SlaMiss,
@@ -83,12 +82,12 @@ def get_task_instance(
SlaMiss.task_id == TI.task_id,
),
)
- .add_entity(SlaMiss)
+ .add_columns(SlaMiss)
.options(joinedload(TI.rendered_task_instance_fields))
)
try:
- task_instance = query.one_or_none()
+ task_instance = session.execute(query).one_or_none()
except MultipleResultsFound:
raise NotFound(
"Task instance not found", detail="Task instance is mapped, add
the map_index value to the URL"
@@ -121,10 +120,8 @@ def get_mapped_task_instance(
) -> APIResponse:
"""Get task instance."""
query = (
- session.query(TI)
- .filter(
- TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id, TI.map_index == map_index
- )
+ select(TI)
+ .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id, TI.map_index == map_index)
.join(TI.dag_run)
.outerjoin(
SlaMiss,
@@ -134,10 +131,11 @@ def get_mapped_task_instance(
SlaMiss.task_id == TI.task_id,
),
)
- .add_entity(SlaMiss)
+ .add_columns(SlaMiss)
.options(joinedload(TI.rendered_task_instance_fields))
)
- task_instance = query.one_or_none()
+ task_instance = session.execute(query).one_or_none()
+
if task_instance is None:
raise NotFound("Task instance not found")
@@ -192,13 +190,14 @@ def get_mapped_task_instances(
states = _convert_state(state)
base_query = (
- session.query(TI)
- .filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id, TI.map_index >= 0)
+ select(TI)
+ .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id, TI.map_index >= 0)
.join(TI.dag_run)
)
# 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404
- if base_query.with_entities(func.count("*")).scalar() == 0:
+ unfiltered_total_count = session.execute(select(func.count("*")).select_from(base_query)).scalar()
+ if unfiltered_total_count == 0:
dag = get_airflow_app().dag_bag.get_dag(dag_id)
if not dag:
error_message = f"DAG {dag_id} not found"
@@ -212,50 +211,54 @@ def get_mapped_task_instances(
raise NotFound(error_message)
# Other search criteria
- query = _apply_range_filter(
+ base_query = _apply_range_filter(
base_query,
key=DR.execution_date,
value_range=(execution_date_gte, execution_date_lte),
)
- query = _apply_range_filter(query, key=TI.start_date, value_range=(start_date_gte, start_date_lte))
- query = _apply_range_filter(query, key=TI.end_date, value_range=(end_date_gte, end_date_lte))
- query = _apply_range_filter(query, key=TI.duration, value_range=(duration_gte, duration_lte))
- query = _apply_range_filter(query, key=TI.updated_at, value_range=(updated_at_gte, updated_at_lte))
- query = _apply_array_filter(query, key=TI.state, values=states)
- query = _apply_array_filter(query, key=TI.pool, values=pool)
- query = _apply_array_filter(query, key=TI.queue, values=queue)
+ base_query = _apply_range_filter(
+ base_query, key=TI.start_date, value_range=(start_date_gte, start_date_lte)
+ )
+ base_query = _apply_range_filter(base_query, key=TI.end_date, value_range=(end_date_gte, end_date_lte))
+ base_query = _apply_range_filter(base_query, key=TI.duration, value_range=(duration_gte, duration_lte))
+ base_query = _apply_range_filter(
+ base_query, key=TI.updated_at, value_range=(updated_at_gte, updated_at_lte)
+ )
+ base_query = _apply_array_filter(base_query, key=TI.state, values=states)
+ base_query = _apply_array_filter(base_query, key=TI.pool, values=pool)
+ base_query = _apply_array_filter(base_query, key=TI.queue, values=queue)
# Count elements before joining extra columns
- total_entries = query.with_entities(func.count("*")).scalar()
+ total_entries = session.execute(select(func.count("*")).select_from(base_query)).scalar()
# Add SLA miss
- query = (
- query.join(
+ entry_query = (
+ base_query.outerjoin(
SlaMiss,
and_(
SlaMiss.dag_id == TI.dag_id,
SlaMiss.task_id == TI.task_id,
SlaMiss.execution_date == DR.execution_date,
),
- isouter=True,
)
- .add_entity(SlaMiss)
+ .add_columns(SlaMiss)
.options(joinedload(TI.rendered_task_instance_fields))
)
if order_by:
if order_by == "state":
- query = query.order_by(TI.state.asc(), TI.map_index.asc())
+ entry_query = entry_query.order_by(TI.state.asc(), TI.map_index.asc())
elif order_by == "-state":
- query = query.order_by(TI.state.desc(), TI.map_index.asc())
+ entry_query = entry_query.order_by(TI.state.desc(), TI.map_index.asc())
elif order_by == "-map_index":
- query = query.order_by(TI.map_index.desc())
+ entry_query = entry_query.order_by(TI.map_index.desc())
else:
raise BadRequest(detail=f"Ordering with '{order_by}' is not
supported")
else:
- query = query.order_by(TI.map_index.asc())
+ entry_query = entry_query.order_by(TI.map_index.asc())
- task_instances = query.offset(offset).limit(limit).all()
+ # using execute because we want the SlaMiss entity. Scalars don't return None for missing entities
+ task_instances = session.execute(entry_query.offset(offset).limit(limit)).all()
return task_instance_collection_schema.dump(
TaskInstanceCollection(task_instances=task_instances, total_entries=total_entries)
)
@@ -267,19 +270,19 @@ def _convert_state(states: Iterable[str] | None) -> list[str | None] | None:
return [State.NONE if s == "none" else s for s in states]
-def _apply_array_filter(query: Query, key: ClauseElement, values: Iterable[Any] | None) -> Query:
+def _apply_array_filter(query: Select, key: ClauseElement, values: Iterable[Any] | None) -> Select:
if values is not None:
cond = ((key == v) for v in values)
- query = query.filter(or_(*cond))
+ query = query.where(or_(*cond))
return query
-def _apply_range_filter(query: Query, key: ClauseElement, value_range: tuple[T, T]) -> Query:
+def _apply_range_filter(query: Select, key: ClauseElement, value_range: tuple[T, T]) -> Select:
gte_value, lte_value = value_range
if gte_value is not None:
- query = query.filter(key >= gte_value)
+ query = query.where(key >= gte_value)
if lte_value is not None:
- query = query.filter(key <= lte_value)
+ query = query.where(key <= lte_value)
return query
@@ -328,12 +331,12 @@ def get_task_instances(
# Because state can be 'none'
states = _convert_state(state)
- base_query = session.query(TI).join(TI.dag_run)
+ base_query = select(TI).join(TI.dag_run)
if dag_id != "~":
- base_query = base_query.filter(TI.dag_id == dag_id)
+ base_query = base_query.where(TI.dag_id == dag_id)
if dag_run_id != "~":
- base_query = base_query.filter(TI.run_id == dag_run_id)
+ base_query = base_query.where(TI.run_id == dag_run_id)
base_query = _apply_range_filter(
base_query,
key=DR.execution_date,
@@ -352,22 +355,26 @@ def get_task_instances(
base_query = _apply_array_filter(base_query, key=TI.queue, values=queue)
# Count elements before joining extra columns
- total_entries = base_query.with_entities(func.count("*")).scalar()
+ count_query = select(func.count("*")).select_from(base_query)
+ total_entries = session.execute(count_query).scalar()
+
# Add join
- query = (
- base_query.join(
+ entry_query = (
+ base_query.outerjoin(
SlaMiss,
and_(
SlaMiss.dag_id == TI.dag_id,
SlaMiss.task_id == TI.task_id,
SlaMiss.execution_date == DR.execution_date,
),
- isouter=True,
)
- .add_entity(SlaMiss)
+ .add_columns(SlaMiss)
.options(joinedload(TI.rendered_task_instance_fields))
+ .offset(offset)
+ .limit(limit)
)
- task_instances = query.offset(offset).limit(limit).all()
+ # using execute because we want the SlaMiss entity. Scalars don't return None for missing entities
+ task_instances = session.execute(entry_query).all()
return task_instance_collection_schema.dump(
TaskInstanceCollection(task_instances=task_instances, total_entries=total_entries)
)
@@ -389,7 +396,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
states = _convert_state(data["state"])
- base_query = session.query(TI).join(TI.dag_run)
+ base_query = select(TI).join(TI.dag_run)
base_query = _apply_array_filter(base_query, key=TI.dag_id, values=data["dag_ids"])
base_query = _apply_range_filter(
@@ -413,7 +420,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
base_query = _apply_array_filter(base_query, key=TI.queue, values=data["queue"])
# Count elements before joining extra columns
- total_entries = base_query.with_entities(func.count("*")).scalar()
+ total_entries = session.execute(select(func.count("*")).select_from(base_query)).scalar()
# Add join
base_query = base_query.join(
SlaMiss,
@@ -423,9 +430,10 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
SlaMiss.execution_date == DR.execution_date,
),
isouter=True,
- ).add_entity(SlaMiss)
+ ).add_columns(SlaMiss)
ti_query = base_query.options(joinedload(TI.rendered_task_instance_fields))
- task_instances = ti_query.all()
+ # using execute because we want the SlaMiss entity. Scalars don't return None for missing entities
+ task_instances = session.execute(ti_query).all()
return task_instance_collection_schema.dump(
TaskInstanceCollection(task_instances=task_instances, total_entries=total_entries)
@@ -461,9 +469,7 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) ->
downstream = data.pop("include_downstream", False)
upstream = data.pop("include_upstream", False)
if dag_run_id is not None:
- dag_run: DR | None = (
- session.query(DR).filter(DR.dag_id == dag_id, DR.run_id == dag_run_id).one_or_none()
- )
+ dag_run: DR | None = session.scalar(select(DR).where(DR.dag_id == dag_id, DR.run_id == dag_run_id))
if dag_run is None:
error_message = f"Dag Run id {dag_run_id} not found in dag
{dag_id}"
raise NotFound(error_message)
@@ -486,16 +492,17 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) ->
# If we had upstream/downstream etc then also include those!
task_ids.extend(tid for tid in dag.task_dict if tid != task_id)
task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, task_ids=task_ids, **data)
+
if not dry_run:
clear_task_instances(
- task_instances.all(),
+ task_instances,
session,
dag=dag,
dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
)
return task_instance_reference_collection_schema.dump(
- TaskInstanceReferenceCollection(task_instances=task_instances.all())
+ TaskInstanceReferenceCollection(task_instances=task_instances)
)
@@ -532,9 +539,11 @@ def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION
if (
execution_date
and (
- session.query(TI)
- .filter(TI.task_id == task_id, TI.dag_id == dag_id, TI.execution_date == execution_date)
- .one_or_none()
+ session.scalars(
+ select(TI).where(
+ TI.task_id == task_id, TI.dag_id == dag_id, TI.execution_date == execution_date
+ )
+ ).one_or_none()
)
is None
):
@@ -653,8 +662,8 @@ def set_task_instance_note(
raise BadRequest(detail=str(err))
query = (
- session.query(TI)
- .filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
+ select(TI)
+ .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
.join(TI.dag_run)
.outerjoin(
SlaMiss,
@@ -664,16 +673,16 @@ def set_task_instance_note(
SlaMiss.task_id == TI.task_id,
),
)
- .add_entity(SlaMiss)
+ .add_columns(SlaMiss)
.options(joinedload(TI.rendered_task_instance_fields))
)
if map_index == -1:
- query = query.filter(or_(TI.map_index == -1, TI.map_index is None))
+ query = query.where(or_(TI.map_index == -1, TI.map_index is None))
else:
- query = query.filter(TI.map_index == map_index)
+ query = query.where(TI.map_index == map_index)
try:
- result = query.one_or_none()
+ result = session.execute(query).one_or_none()
except MultipleResultsFound:
raise NotFound(
"Task instance not found", detail="Task instance is mapped, add
the map_index value to the URL"
diff --git a/airflow/api_connexion/endpoints/user_endpoint.py b/airflow/api_connexion/endpoints/user_endpoint.py
index a9482701a4..2a88fb1b24 100644
--- a/airflow/api_connexion/endpoints/user_endpoint.py
+++ b/airflow/api_connexion/endpoints/user_endpoint.py
@@ -21,7 +21,7 @@ from http import HTTPStatus
from connexion import NoContent
from flask import request
from marshmallow import ValidationError
-from sqlalchemy import asc, desc, func
+from sqlalchemy import asc, desc, func, select
from werkzeug.security import generate_password_hash
from airflow.api_connexion import security
@@ -55,7 +55,7 @@ def get_users(*, limit: int, order_by: str = "id", offset:
str | None = None) ->
"""Get users."""
appbuilder = get_airflow_app().appbuilder
session = appbuilder.get_session
- total_entries = session.query(func.count(User.id)).scalar()
+ total_entries = session.execute(select(func.count(User.id))).scalar()
direction = desc if order_by.startswith("-") else asc
to_replace = {"user_id": "id"}
order_param = order_by.strip("-")
@@ -75,8 +75,8 @@ def get_users(*, limit: int, order_by: str = "id", offset:
str | None = None) ->
f"the attribute does not exist on the model"
)
- query = session.query(User)
- users = query.order_by(direction(getattr(User, order_param))).offset(offset).limit(limit).all()
+ query = select(User).order_by(direction(getattr(User, order_param))).offset(offset).limit(limit)
+ users = session.scalars(query).all()
return user_collection_schema.dump(UserCollection(users=users, total_entries=total_entries))
diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py
index da8f35fcb8..61a1871104 100644
--- a/airflow/api_connexion/endpoints/variable_endpoint.py
+++ b/airflow/api_connexion/endpoints/variable_endpoint.py
@@ -20,7 +20,7 @@ from http import HTTPStatus
from flask import Response
from marshmallow import ValidationError
-from sqlalchemy import func
+from sqlalchemy import func, select
from sqlalchemy.orm import Session
from airflow.api_connexion import security
@@ -57,10 +57,10 @@ def delete_variable(*, variable_key: str) -> Response:
@provide_session
def get_variable(*, variable_key: str, session: Session = NEW_SESSION) -> Response:
"""Get a variable by key."""
- var = session.query(Variable).filter(Variable.key == variable_key)
- if not var.count():
+ var = session.scalar(select(Variable).where(Variable.key == variable_key).limit(1))
+ if not var:
raise NotFound("Variable not found")
- return variable_schema.dump(var.first())
+ return variable_schema.dump(var)
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE)])
@@ -74,12 +74,12 @@ def get_variables(
session: Session = NEW_SESSION,
) -> Response:
"""Get all variable values."""
- total_entries = session.query(func.count(Variable.id)).scalar()
+ total_entries = session.execute(select(func.count(Variable.id))).scalar()
to_replace = {"value": "val"}
allowed_filter_attrs = ["value", "key", "id"]
- query = session.query(Variable)
+ query = select(Variable)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
- variables = query.offset(offset).limit(limit).all()
+ variables = session.scalars(query.offset(offset).limit(limit)).all()
return variable_collection_schema.dump(
{
"variables": variables,
@@ -111,7 +111,7 @@ def patch_variable(
if data["key"] != variable_key:
raise BadRequest("Invalid post body", detail="key from request body
doesn't match uri parameter")
non_update_fields = ["key"]
- variable = session.query(Variable).filter_by(key=variable_key).first()
+ variable = session.scalar(select(Variable).filter_by(key=variable_key).limit(1))
if update_mask:
data = extract_update_mask_data(update_mask, non_update_fields, data)
for key, val in data.items():
diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py
index 2ab5ec26f5..830cedb51c 100644
--- a/airflow/api_connexion/endpoints/xcom_endpoint.py
+++ b/airflow/api_connexion/endpoints/xcom_endpoint.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import copy
from flask import g
-from sqlalchemy import and_
+from sqlalchemy import and_, func, select
from sqlalchemy.orm import Session
from airflow.api_connexion import security
@@ -53,23 +53,23 @@ def get_xcom_entries(
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get all XCom values."""
- query = session.query(XCom)
+ query = select(XCom)
if dag_id == "~":
appbuilder = get_airflow_app().appbuilder
readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
- query = query.filter(XCom.dag_id.in_(readable_dag_ids))
+ query = query.where(XCom.dag_id.in_(readable_dag_ids))
query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
else:
- query = query.filter(XCom.dag_id == dag_id)
+ query = query.where(XCom.dag_id == dag_id)
query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
if task_id != "~":
- query = query.filter(XCom.task_id == task_id)
+ query = query.where(XCom.task_id == task_id)
if dag_run_id != "~":
- query = query.filter(DR.run_id == dag_run_id)
+ query = query.where(DR.run_id == dag_run_id)
query = query.order_by(DR.execution_date, XCom.task_id, XCom.dag_id, XCom.key)
- total_entries = query.count()
- query = query.offset(offset).limit(limit)
+ total_entries = session.execute(select(func.count()).select_from(query)).scalar()
+ query = session.scalars(query.offset(offset).limit(limit))
return xcom_collection_schema.dump(XComCollection(xcom_entries=query, total_entries=total_entries))
@@ -93,15 +93,19 @@ def get_xcom_entry(
) -> APIResponse:
"""Get an XCom entry."""
if deserialize:
- query = session.query(XCom, XCom.value)
+ query = select(XCom, XCom.value)
else:
- query = session.query(XCom)
+ query = select(XCom)
- query = query.filter(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key)
+ query = query.where(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key)
query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
- query = query.filter(DR.run_id == dag_run_id)
+ query = query.where(DR.run_id == dag_run_id)
+
+ if deserialize:
+ item = session.execute(query).one_or_none()
+ else:
+ item = session.scalars(query).one_or_none()
- item = query.one_or_none()
if item is None:
raise NotFound("XCom entry not found")
diff --git a/airflow/api_connexion/parameters.py b/airflow/api_connexion/parameters.py
index 5f8c8ad360..f4f55cfecd 100644
--- a/airflow/api_connexion/parameters.py
+++ b/airflow/api_connexion/parameters.py
@@ -23,7 +23,7 @@ from typing import Any, Callable, Container, TypeVar, cast
from pendulum.parsing import ParserError
from sqlalchemy import text
-from sqlalchemy.orm.query import Query
+from sqlalchemy.sql import Select
from airflow.api_connexion.exceptions import BadRequest
from airflow.configuration import conf
@@ -106,11 +106,11 @@ def format_parameters(params_formatters: dict[str, Callable[[Any], Any]]) -> Cal
def apply_sorting(
- query: Query,
+ query: Select,
order_by: str,
to_replace: dict[str, str] | None = None,
allowed_attrs: Container[str] | None = None,
-) -> Query:
+) -> Select:
"""Apply sorting to query."""
lstriped_orderby = order_by.lstrip("-")
if allowed_attrs and lstriped_orderby not in allowed_attrs:
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index ff4f5c4140..73d9204aab 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -22,6 +22,8 @@ import inspect
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence
+from sqlalchemy import select
+
from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.exceptions import AirflowException
@@ -549,17 +551,15 @@ class AbstractOperator(Templater, DAGNode):
total_length = None
state: TaskInstanceState | None = None
- unmapped_ti: TaskInstance | None = (
- session.query(TaskInstance)
- .filter(
+ unmapped_ti: TaskInstance | None = session.scalars(
+ select(TaskInstance).where(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.run_id == run_id,
TaskInstance.map_index == -1,
or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
)
- .one_or_none()
- )
+ ).one_or_none()
all_expanded_tis: list[TaskInstance] = []
@@ -582,14 +582,14 @@ class AbstractOperator(Templater, DAGNode):
unmapped_ti.state = TaskInstanceState.SKIPPED
else:
zero_index_ti_exists = (
- session.query(TaskInstance)
- .filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.run_id == run_id,
- TaskInstance.map_index == 0,
+ session.scalar(
+ select(func.count(TaskInstance.task_id)).where(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == self.task_id,
+ TaskInstance.run_id == run_id,
+ TaskInstance.map_index == 0,
+ )
)
- .count()
> 0
)
if not zero_index_ti_exists:
@@ -609,14 +609,12 @@ class AbstractOperator(Templater, DAGNode):
indexes_to_map: Iterable[int] = ()
else:
# Only create "missing" ones.
- current_max_mapping = (
- session.query(func.max(TaskInstance.map_index))
- .filter(
+ current_max_mapping = session.scalar(
+ select(func.max(TaskInstance.map_index)).where(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.run_id == run_id,
)
- .scalar()
)
indexes_to_map = range(current_max_mapping + 1, total_length)
@@ -635,13 +633,14 @@ class AbstractOperator(Templater, DAGNode):
# Any (old) task instances with inapplicable indexes (>= the total
# number we need) are set to "REMOVED".
- query = session.query(TaskInstance).filter(
+ query = select(TaskInstance).where(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.run_id == run_id,
TaskInstance.map_index >= total_expanded_ti_count,
)
- to_update = with_row_locks(query, of=TaskInstance, session=session, **skip_locked(session=session))
+ query = with_row_locks(query, of=TaskInstance, session=session, **skip_locked(session=session))
+ to_update = session.scalars(query)
for ti in to_update:
ti.state = TaskInstanceState.REMOVED
session.flush()
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 3d0812b62a..12f4206fcf 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -49,6 +49,7 @@ from typing import (
import attr
import pendulum
from dateutil.relativedelta import relativedelta
+from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import NoResultFound
@@ -1256,12 +1257,12 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
Clears the state of task instances associated with the task, following
the parameters specified.
"""
-        qry = session.query(TaskInstance).filter(TaskInstance.dag_id == self.dag_id)
+ qry = select(TaskInstance).where(TaskInstance.dag_id == self.dag_id)
if start_date:
- qry = qry.filter(TaskInstance.execution_date >= start_date)
+ qry = qry.where(TaskInstance.execution_date >= start_date)
if end_date:
- qry = qry.filter(TaskInstance.execution_date <= end_date)
+ qry = qry.where(TaskInstance.execution_date <= end_date)
tasks = [self.task_id]
@@ -1271,8 +1272,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
if downstream:
            tasks += [t.task_id for t in self.get_flat_relatives(upstream=False)]
- qry = qry.filter(TaskInstance.task_id.in_(tasks))
- results = qry.all()
+ qry = qry.where(TaskInstance.task_id.in_(tasks))
+ results = session.scalars(qry).all()
count = len(results)
clear_task_instances(results, session, dag=self.dag)
session.commit()
@@ -1289,16 +1290,15 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
from airflow.models import DagRun
end_date = end_date or timezone.utcnow()
- return (
- session.query(TaskInstance)
+ return session.scalars(
+ select(TaskInstance)
.join(TaskInstance.dag_run)
- .filter(TaskInstance.dag_id == self.dag_id)
- .filter(TaskInstance.task_id == self.task_id)
- .filter(DagRun.execution_date >= start_date)
- .filter(DagRun.execution_date <= end_date)
+ .where(TaskInstance.dag_id == self.dag_id)
+ .where(TaskInstance.task_id == self.task_id)
+ .where(DagRun.execution_date >= start_date)
+ .where(DagRun.execution_date <= end_date)
.order_by(DagRun.execution_date)
- .all()
- )
+ ).all()
@provide_session
def run(
@@ -1327,14 +1327,12 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
        for info in self.dag.iter_dagrun_infos_between(start_date, end_date, align=False):
            ignore_depends_on_past = info.logical_date == start_date and ignore_first_depends_on_past
try:
- dag_run = (
- session.query(DagRun)
- .filter(
+ dag_run = session.scalars(
+ select(DagRun).where(
DagRun.dag_id == self.dag_id,
DagRun.execution_date == info.logical_date,
)
- .one()
- )
+ ).one()
ti = TaskInstance(self, run_id=dag_run.run_id)
except NoResultFound:
# This is _mostly_ only used in tests
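The clear() and get_task_instances() changes above show the incremental style
that select() shares with Query: each .where() call ANDs another predicate
onto the statement, and nothing touches the database until the session
executes it. A hedged sketch of the same shape (model and session assumed):

    from sqlalchemy import select

    stmt = select(TaskInstance).where(TaskInstance.dag_id == dag_id)
    if start_date:
        stmt = stmt.where(TaskInstance.execution_date >= start_date)  # AND-ed on
    if end_date:
        stmt = stmt.where(TaskInstance.execution_date <= end_date)
    results = session.scalars(stmt).all()  # executed only here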
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 7e2f23f47b..892ff5cef4 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -55,12 +55,27 @@ import pendulum
import re2 as re
from dateutil.relativedelta import relativedelta
from pendulum.tz.timezone import Timezone
-from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, and_, case, func, not_, or_
+from sqlalchemy import (
+ Boolean,
+ Column,
+ ForeignKey,
+ Index,
+ Integer,
+ String,
+ Text,
+ and_,
+ case,
+ func,
+ not_,
+ or_,
+ select,
+ update,
+)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import backref, joinedload, relationship
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
-from sqlalchemy.sql import expression
+from sqlalchemy.sql import Select, expression
import airflow.templates
from airflow import settings, utils
@@ -202,11 +217,11 @@ def get_last_dagrun(dag_id, session, include_externally_triggered=False):
Overridden DagRuns are ignored.
"""
DR = DagRun
- query = session.query(DR).filter(DR.dag_id == dag_id)
+ query = select(DR).where(DR.dag_id == dag_id)
if not include_externally_triggered:
- query = query.filter(DR.external_trigger == expression.false())
+ query = query.where(DR.external_trigger == expression.false())
query = query.order_by(DR.execution_date.desc())
- return query.first()
+ return session.scalar(query.limit(1))
def get_dataset_triggered_next_run_info(
@@ -224,31 +239,27 @@ def get_dataset_triggered_next_run_info(
"ready": x.ready,
"total": x.total,
}
- for x in session.query(
- DagScheduleDatasetReference.dag_id,
-            # This is a dirty hack to workaround group by requiring an aggregate, since grouping by dataset
-            # is not what we want to do here...but it works
-            case((func.count() == 1, func.max(DatasetModel.uri)), else_="").label("uri"),
-            func.count().label("total"),
-            func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)).label("ready"),
- )
- .join(
- DDRQ,
- and_(
- DDRQ.dataset_id == DagScheduleDatasetReference.dataset_id,
- DDRQ.target_dag_id == DagScheduleDatasetReference.dag_id,
- ),
- isouter=True,
- )
- .join(
- DatasetModel,
- DatasetModel.id == DagScheduleDatasetReference.dataset_id,
- )
- .group_by(
- DagScheduleDatasetReference.dag_id,
- )
- .filter(DagScheduleDatasetReference.dag_id.in_(dag_ids))
- .all()
+ for x in session.execute(
+ select(
+ DagScheduleDatasetReference.dag_id,
+                # This is a dirty hack to workaround group by requiring an aggregate,
+                # since grouping by dataset is not what we want to do here...but it works
+                case((func.count() == 1, func.max(DatasetModel.uri)), else_="").label("uri"),
+                func.count().label("total"),
+                func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)).label("ready"),
+ )
+ .join(
+ DDRQ,
+ and_(
+ DDRQ.dataset_id == DagScheduleDatasetReference.dataset_id,
+ DDRQ.target_dag_id == DagScheduleDatasetReference.dag_id,
+ ),
+ isouter=True,
+ )
+            .join(DatasetModel, DatasetModel.id == DagScheduleDatasetReference.dataset_id)
+ .group_by(DagScheduleDatasetReference.dag_id)
+ .where(DagScheduleDatasetReference.dag_id.in_(dag_ids))
+ ).all()
}
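The hunk above also marks the split this refactor follows throughout:
session.scalars() is for statements selecting whole ORM entities, while
column-tuple statements like this one go through session.execute(), which
yields Row objects. A sketch of the distinction (models and session assumed):

    from sqlalchemy import func, select

    # Whole entities: scalars() unwraps the one-element rows.
    runs = session.scalars(select(DagRun)).all()  # -> list[DagRun]

    # Several columns: execute() keeps the row shape; rows unpack like tuples.
    for dag_id, total in session.execute(
        select(DagRun.dag_id, func.count().label("total")).group_by(DagRun.dag_id)
    ):
        print(dag_id, total)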
@@ -1296,11 +1307,13 @@ class DAG(LoggingMixin):
has been reached.
"""
TI = TaskInstance
- qry = session.query(func.count(TI.task_id)).filter(
- TI.dag_id == self.dag_id,
- TI.state == State.RUNNING,
+ total_tasks = session.scalar(
+ select(func.count(TI.task_id)).where(
+ TI.dag_id == self.dag_id,
+ TI.state == TaskInstanceState.RUNNING,
+ )
)
- return qry.scalar() >= self.max_active_tasks
+ return total_tasks >= self.max_active_tasks
@property
def concurrency_reached(self):
@@ -1315,12 +1328,12 @@ class DAG(LoggingMixin):
@provide_session
def get_is_active(self, session=NEW_SESSION) -> None:
"""Returns a boolean indicating whether this DAG is active."""
-        return session.query(DagModel.is_active).filter(DagModel.dag_id == self.dag_id).scalar()
+        return session.scalar(select(DagModel.is_active).where(DagModel.dag_id == self.dag_id))
@provide_session
def get_is_paused(self, session=NEW_SESSION) -> None:
"""Returns a boolean indicating whether this DAG is paused."""
-        return session.query(DagModel.is_paused).filter(DagModel.dag_id == self.dag_id).scalar()
+        return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id == self.dag_id))
@property
def is_paused(self):
@@ -1402,19 +1415,18 @@ class DAG(LoggingMixin):
:param session:
:return: number greater than 0 for active dag runs
"""
- # .count() is inefficient
-        query = session.query(func.count()).filter(DagRun.dag_id == self.dag_id)
+ query = select(func.count()).where(DagRun.dag_id == self.dag_id)
if only_running:
- query = query.filter(DagRun.state == State.RUNNING)
+ query = query.where(DagRun.state == DagRunState.RUNNING)
else:
-            query = query.filter(DagRun.state.in_({State.RUNNING, State.QUEUED}))
+            query = query.where(DagRun.state.in_({DagRunState.RUNNING, DagRunState.QUEUED}))
if external_trigger is not None:
- query = query.filter(
+ query = query.where(
                DagRun.external_trigger == (expression.true() if external_trigger else expression.false())
)
- return query.scalar()
+ return session.scalar(query)
@provide_session
def get_dagrun(
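As the deleted comment notes, Query.count() is inefficient because it wraps
the whole query in a subquery just to count it; the 2.0 form states the COUNT
directly. A minimal sketch (the FROM clause is inferred from the WHERE
criteria, exactly as in the hunk above):

    from sqlalchemy import func, select

    n_active = session.scalar(
        select(func.count()).where(DagRun.dag_id == dag_id)
    )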
@@ -1434,12 +1446,12 @@ class DAG(LoggingMixin):
"""
if not (execution_date or run_id):
            raise TypeError("You must provide either the execution_date or the run_id")
- query = session.query(DagRun)
+ query = select(DagRun)
if execution_date:
-            query = query.filter(DagRun.dag_id == self.dag_id, DagRun.execution_date == execution_date)
+            query = query.where(DagRun.dag_id == self.dag_id, DagRun.execution_date == execution_date)
        if run_id:
-            query = query.filter(DagRun.dag_id == self.dag_id, DagRun.run_id == run_id)
-        return query.first()
+            query = query.where(DagRun.dag_id == self.dag_id, DagRun.run_id == run_id)
+ return session.scalar(query)
@provide_session
def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION):
@@ -1451,22 +1463,20 @@ class DAG(LoggingMixin):
:param session:
:return: The list of DagRuns found.
"""
- dagruns = (
- session.query(DagRun)
- .filter(
+ dagruns = session.scalars(
+ select(DagRun).where(
DagRun.dag_id == self.dag_id,
DagRun.execution_date >= start_date,
DagRun.execution_date <= end_date,
)
- .all()
- )
+ ).all()
return dagruns
@provide_session
    def get_latest_execution_date(self, session: Session = NEW_SESSION) -> pendulum.DateTime | None:
        """Returns the latest date for which at least one dag run exists."""
-        return session.query(func.max(DagRun.execution_date)).filter(DagRun.dag_id == self.dag_id).scalar()
+        return session.scalar(select(func.max(DagRun.execution_date)).where(DagRun.dag_id == self.dag_id))
@property
def latest_execution_date(self):
@@ -1553,16 +1563,15 @@ class DAG(LoggingMixin):
corresponding to any DagRunType. It can have less if there are
less than ``num`` scheduled DAG runs before ``base_date``.
"""
- execution_dates: list[Any] = (
- session.query(DagRun.execution_date)
- .filter(
+ execution_dates: list[Any] = session.execute(
+ select(DagRun.execution_date)
+ .where(
DagRun.dag_id == self.dag_id,
DagRun.execution_date <= base_date,
)
.order_by(DagRun.execution_date.desc())
.limit(num)
- .all()
- )
+ ).all()
if len(execution_dates) == 0:
            return self.get_task_instances(start_date=base_date, end_date=base_date, session=session)
@@ -1598,7 +1607,7 @@ class DAG(LoggingMixin):
exclude_task_ids=(),
session=session,
)
- return cast(Query, query).order_by(DagRun.execution_date).all()
+        return session.scalars(cast(Select, query).order_by(DagRun.execution_date)).all()
@overload
def _get_task_instances(
@@ -1671,9 +1680,9 @@ class DAG(LoggingMixin):
# Do we want full objects, or just the primary columns?
if as_pk_tuple:
- tis = session.query(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
+ tis = select(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
else:
- tis = session.query(TaskInstance)
+ tis = select(TaskInstance)
tis = tis.join(TaskInstance.dag_run)
if include_subdags:
@@ -1683,40 +1692,40 @@ class DAG(LoggingMixin):
conditions.append(
                    (TaskInstance.dag_id == dag.dag_id) & TaskInstance.task_id.in_(dag.task_ids)
)
- tis = tis.filter(or_(*conditions))
+ tis = tis.where(or_(*conditions))
elif self.partial:
-            tis = tis.filter(TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids))
+            tis = tis.where(TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids))
else:
- tis = tis.filter(TaskInstance.dag_id == self.dag_id)
+ tis = tis.where(TaskInstance.dag_id == self.dag_id)
if run_id:
- tis = tis.filter(TaskInstance.run_id == run_id)
+ tis = tis.where(TaskInstance.run_id == run_id)
if start_date:
- tis = tis.filter(DagRun.execution_date >= start_date)
+ tis = tis.where(DagRun.execution_date >= start_date)
if task_ids is not None:
- tis = tis.filter(TaskInstance.ti_selector_condition(task_ids))
+ tis = tis.where(TaskInstance.ti_selector_condition(task_ids))
        # This allows the allow_trigger_in_future config to take effect, rather than mandating exec_date <= UTC
if end_date or not self.allow_future_exec_dates:
end_date = end_date or timezone.utcnow()
- tis = tis.filter(DagRun.execution_date <= end_date)
+ tis = tis.where(DagRun.execution_date <= end_date)
if state:
if isinstance(state, (str, TaskInstanceState)):
- tis = tis.filter(TaskInstance.state == state)
+ tis = tis.where(TaskInstance.state == state)
elif len(state) == 1:
- tis = tis.filter(TaskInstance.state == state[0])
+ tis = tis.where(TaskInstance.state == state[0])
else:
# this is required to deal with NULL values
if None in state:
if all(x is None for x in state):
- tis = tis.filter(TaskInstance.state.is_(None))
+ tis = tis.where(TaskInstance.state.is_(None))
else:
not_none_state = [s for s in state if s]
- tis = tis.filter(
+ tis = tis.where(
                            or_(TaskInstance.state.in_(not_none_state), TaskInstance.state.is_(None))
)
else:
- tis = tis.filter(TaskInstance.state.in_(state))
+ tis = tis.where(TaskInstance.state.in_(state))
# Next, get any of them from our parent DAG (if there is one)
if include_parentdag and self.parent_dag is not None:
@@ -1754,14 +1763,17 @@ class DAG(LoggingMixin):
query = tis
if as_pk_tuple:
-                condition = TI.filter_for_tis(TaskInstanceKey(*cols) for cols in tis.all())
+                all_tis = session.execute(query).all()
+                condition = TI.filter_for_tis(TaskInstanceKey(*cols) for cols in all_tis)
if condition is not None:
- query = session.query(TI).filter(condition)
+ query = select(TI).where(condition)
if visited_external_tis is None:
visited_external_tis = set()
- for ti in query.filter(TI.operator == ExternalTaskMarker.__name__):
+        external_tasks = session.scalars(query.where(TI.operator == ExternalTaskMarker.__name__))
+
+ for ti in external_tasks:
ti_key = ti.key.primary
if ti_key in visited_external_tis:
continue
@@ -1784,10 +1796,10 @@ class DAG(LoggingMixin):
f"Attempted to clear too many tasks or there may be a
cyclic dependency."
)
ti.render_templates()
- external_tis = (
- session.query(TI)
+ external_tis = session.scalars(
+ select(TI)
.join(TI.dag_run)
- .filter(
+ .where(
TI.dag_id == task.external_dag_id,
TI.task_id == task.external_task_id,
                    DagRun.execution_date == pendulum.parse(task.execution_date),
@@ -1830,9 +1842,10 @@ class DAG(LoggingMixin):
if result or as_pk_tuple:
        # Only execute the `ti` query if we have also collected some other results (i.e. subdags etc.)
if as_pk_tuple:
-                result.update(TaskInstanceKey(**cols._mapping) for cols in tis.all())
+                tis_query = session.execute(tis).all()
+                result.update(TaskInstanceKey(**cols._mapping) for cols in tis_query)
else:
- result.update(ti.key for ti in tis)
+ result.update(ti.key for ti in session.scalars(tis))
if exclude_task_ids is not None:
result = {
@@ -1848,13 +1861,13 @@ class DAG(LoggingMixin):
            # We've been asked for objects, let's combine it all back into a result set
ti_filters = TI.filter_for_tis(result)
if ti_filters is not None:
- tis = session.query(TI).filter(ti_filters)
+ tis = select(TI).where(ti_filters)
elif exclude_task_ids is None:
pass # Disable filter if not set.
elif isinstance(next(iter(exclude_task_ids), None), str):
- tis = tis.filter(TI.task_id.notin_(exclude_task_ids))
+ tis = tis.where(TI.task_id.notin_(exclude_task_ids))
else:
-            tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids)))
+            tis = tis.where(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids)))
return tis
@@ -1930,9 +1943,9 @@ class DAG(LoggingMixin):
)
if execution_date is None:
- dag_run = (
-                session.query(DagRun).filter(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id).one()
-            )  # Raises an error if not found
+            dag_run = session.scalars(
+                select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id)
+ ).one() # Raises an error if not found
resolve_execution_date = dag_run.execution_date
else:
resolve_execution_date = execution_date
@@ -1993,9 +2006,9 @@ class DAG(LoggingMixin):
locked_dag_run_ids: list[int] = []
if execution_date is None:
- dag_run = (
-                session.query(DagRun).filter(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id).one()
-            )  # Raises an error if not found
+            dag_run = session.scalars(
+                select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id)
+ ).one() # Raises an error if not found
resolve_execution_date = dag_run.execution_date
else:
resolve_execution_date = execution_date
@@ -2009,16 +2022,16 @@ class DAG(LoggingMixin):
            raise ValueError(f"TaskGroup {group_id} could not be found")
        tasks_to_set_state = [task for task in task_group.iter_tasks() if isinstance(task, BaseOperator)]
        task_ids = [task.task_id for task in task_group.iter_tasks()]
-        dag_runs_query = session.query(DagRun.id).filter(DagRun.dag_id == self.dag_id).with_for_update()
+        dag_runs_query = session.query(DagRun.id).where(DagRun.dag_id == self.dag_id).with_for_update()
if start_date is None and end_date is None:
-            dag_runs_query = dag_runs_query.filter(DagRun.execution_date == start_date)
+            dag_runs_query = dag_runs_query.where(DagRun.execution_date == start_date)
        else:
            if start_date is not None:
-                dag_runs_query = dag_runs_query.filter(DagRun.execution_date >= start_date)
+                dag_runs_query = dag_runs_query.where(DagRun.execution_date >= start_date)
            if end_date is not None:
-                dag_runs_query = dag_runs_query.filter(DagRun.execution_date <= end_date)
+                dag_runs_query = dag_runs_query.where(DagRun.execution_date <= end_date)
locked_dag_run_ids = dag_runs_query.all()
@@ -2105,12 +2118,12 @@ class DAG(LoggingMixin):
stacklevel=3,
)
dag_ids = dag_ids or [self.dag_id]
- query = session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids))
+ query = update(DagRun).where(DagRun.dag_id.in_(dag_ids))
if start_date:
- query = query.filter(DagRun.execution_date >= start_date)
+ query = query.where(DagRun.execution_date >= start_date)
if end_date:
- query = query.filter(DagRun.execution_date <= end_date)
- query.update({DagRun.state: state}, synchronize_session="fetch")
+ query = query.where(DagRun.execution_date <= end_date)
+        session.execute(query.values(state=state).execution_options(synchronize_session="fetch"))
@provide_session
def clear(
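set_dag_runs_state() above shows the bulk-UPDATE half of the migration: the
ORM-enabled Query.update() becomes an update() statement, the new values move
into .values(), and synchronize_session moves into the execution options. A
sketch of the same shape, assuming DagRun and a target state:

    from sqlalchemy import update

    stmt = (
        update(DagRun)
        .where(DagRun.dag_id.in_(dag_ids))
        .values(state=state)
        .execution_options(synchronize_session="fetch")  # refresh in-session objects
    )
    session.execute(stmt)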
@@ -2196,11 +2209,11 @@ class DAG(LoggingMixin):
)
if dry_run:
- return tis
+ return session.scalars(tis).all()
- tis = list(tis)
+ tis = session.scalars(tis).all()
- count = len(tis)
+ count = len(list(tis))
do_it = True
if count == 0:
return 0
@@ -2213,7 +2226,7 @@ class DAG(LoggingMixin):
if do_it:
clear_task_instances(
- tis,
+ list(tis),
session,
dag=self,
dag_run_state=dag_run_state,
@@ -2462,10 +2475,10 @@ class DAG(LoggingMixin):
@provide_session
def pickle(self, session=NEW_SESSION) -> DagPickle:
-        dag = session.query(DagModel).filter(DagModel.dag_id == self.dag_id).first()
+        dag = session.scalar(select(DagModel).where(DagModel.dag_id == self.dag_id).limit(1))
        dp = None
        if dag and dag.pickle_id:
-            dp = session.query(DagPickle).filter(DagPickle.id == dag.pickle_id).first()
+            dp = session.scalar(select(DagPickle).where(DagPickle.id == dag.pickle_id).limit(1))
if not dp or dp.pickle != self:
dp = DagPickle(dag=self)
session.add(dp)
@@ -2881,13 +2894,14 @@ class DAG(LoggingMixin):
dag_ids = set(dag_by_ids.keys())
query = (
- session.query(DagModel)
+ select(DagModel)
.options(joinedload(DagModel.tags, innerjoin=False))
- .filter(DagModel.dag_id.in_(dag_ids))
+ .where(DagModel.dag_id.in_(dag_ids))
.options(joinedload(DagModel.schedule_dataset_references))
.options(joinedload(DagModel.task_outlet_dataset_references))
)
-        orm_dags: list[DagModel] = with_row_locks(query, of=DagModel, session=session).all()
+ query = with_row_locks(query, of=DagModel, session=session)
+ orm_dags: list[DagModel] = session.scalars(query).unique().all()
existing_dags = {orm_dag.dag_id: orm_dag for orm_dag in orm_dags}
missing_dag_ids = dag_ids.difference(existing_dags)
@@ -2903,17 +2917,19 @@ class DAG(LoggingMixin):
        # Get the latest dag run for each existing dag as a single query (avoid n+1 query)
most_recent_subq = (
-            session.query(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
-            .filter(
+            select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
+ .where(
DagRun.dag_id.in_(existing_dags),
                or_(DagRun.run_type == DagRunType.BACKFILL_JOB, DagRun.run_type == DagRunType.SCHEDULED),
)
.group_by(DagRun.dag_id)
.subquery()
)
- most_recent_runs_iter = session.query(DagRun).filter(
- DagRun.dag_id == most_recent_subq.c.dag_id,
- DagRun.execution_date == most_recent_subq.c.max_execution_date,
+ most_recent_runs_iter = session.scalars(
+ select(DagRun).where(
+ DagRun.dag_id == most_recent_subq.c.dag_id,
+ DagRun.execution_date == most_recent_subq.c.max_execution_date,
+ )
)
most_recent_runs = {run.dag_id: run for run in most_recent_runs_iter}
@@ -3029,7 +3045,9 @@ class DAG(LoggingMixin):
# store datasets
stored_datasets = {}
for dataset in all_datasets:
-            stored_dataset = session.query(DatasetModel).filter(DatasetModel.uri == dataset.uri).first()
+            stored_dataset = session.scalar(
+                select(DatasetModel).where(DatasetModel.uri == dataset.uri).limit(1)
+ )
if stored_dataset:
                # Some datasets may have been previously unreferenced, and therefore orphaned by the
                # scheduler. But if we're here, then we have found that dataset again in our DAGs, which
@@ -3114,7 +3132,7 @@ class DAG(LoggingMixin):
"""
if len(active_dag_ids) == 0:
return
-        for dag in session.query(DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all():
+        for dag in session.scalars(select(DagModel).where(~DagModel.dag_id.in_(active_dag_ids))).all():
dag.is_active = False
session.merge(dag)
session.commit()
@@ -3130,10 +3148,8 @@ class DAG(LoggingMixin):
time
:return: None
"""
- for dag in (
- session.query(DagModel)
-            .filter(DagModel.last_parsed_time < expiration_date, DagModel.is_active)
-            .all()
+        for dag in session.scalars(
+            select(DagModel).where(DagModel.last_parsed_time < expiration_date, DagModel.is_active)
):
log.info(
"Deactivating DAG ID %s since it was last touched by the
scheduler at %s",
@@ -3157,30 +3173,30 @@ class DAG(LoggingMixin):
:param states: A list of states to filter by if supplied
:return: The number of running tasks
"""
- qry = session.query(func.count(TaskInstance.task_id)).filter(
+ qry = select(func.count(TaskInstance.task_id)).where(
TaskInstance.dag_id == dag_id,
)
if run_id:
- qry = qry.filter(
+ qry = qry.where(
TaskInstance.run_id == run_id,
)
if task_ids:
- qry = qry.filter(
+ qry = qry.where(
TaskInstance.task_id.in_(task_ids),
)
if states:
if None in states:
if all(x is None for x in states):
- qry = qry.filter(TaskInstance.state.is_(None))
+ qry = qry.where(TaskInstance.state.is_(None))
else:
not_none_states = [state for state in states if state]
- qry = qry.filter(
+ qry = qry.where(
                    or_(TaskInstance.state.in_(not_none_states), TaskInstance.state.is_(None))
)
else:
- qry = qry.filter(TaskInstance.state.in_(states))
- return qry.scalar()
+ qry = qry.where(TaskInstance.state.in_(states))
+ return session.scalar(qry)
@classmethod
def get_serialized_fields(cls):
@@ -3301,7 +3317,7 @@ class DagOwnerAttributes(Base):
@classmethod
def get_all(cls, session) -> dict[str, dict[str, str]]:
dag_links: dict = collections.defaultdict(dict)
- for obj in session.query(cls):
+ for obj in session.scalars(select(cls)):
dag_links[obj.dag_id].update({obj.owner: obj.link})
return dag_links
@@ -3447,7 +3463,7 @@ class DagModel(Base):
@classmethod
@provide_session
def get_current(cls, dag_id, session=NEW_SESSION):
- return session.query(cls).filter(cls.dag_id == dag_id).first()
+ return session.scalar(select(cls).where(cls.dag_id == dag_id))
@provide_session
    def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False):
@@ -3470,11 +3486,10 @@ class DagModel(Base):
:param session: ORM Session
:return: Paused Dag_ids
"""
- paused_dag_ids = (
- session.query(DagModel.dag_id)
- .filter(DagModel.is_paused == expression.true())
- .filter(DagModel.dag_id.in_(dag_ids))
- .all()
+ paused_dag_ids = session.execute(
+ select(DagModel.dag_id)
+ .where(DagModel.is_paused == expression.true())
+ .where(DagModel.dag_id.in_(dag_ids))
)
paused_dag_ids = {paused_dag_id for paused_dag_id, in paused_dag_ids}
@@ -3518,8 +3533,11 @@ class DagModel(Base):
]
if including_subdags:
filter_query.append(DagModel.root_dag_id == self.dag_id)
- session.query(DagModel).filter(or_(*filter_query)).update(
- {DagModel.is_paused: is_paused}, synchronize_session="fetch"
+ session.execute(
+ update(DagModel)
+ .where(or_(*filter_query))
+ .values(is_paused=is_paused)
+ .execution_options(synchronize_session="fetch")
)
session.commit()
@@ -3538,7 +3556,8 @@ class DagModel(Base):
:param session: ORM Session
"""
        log.debug("Deactivating DAGs (for which DAG files are deleted) from %s table ", cls.__tablename__)
-        for dag_model in session.query(cls).filter(cls.fileloc.is_not(None)):
+        dag_models = session.scalars(select(cls).where(cls.fileloc.is_not(None)))
+        for dag_model in dag_models:
+ for dag_model in dag_models:
if dag_model.fileloc not in alive_dag_filelocs:
dag_model.is_active = False
@@ -3556,28 +3575,30 @@ class DagModel(Base):
# these dag ids are triggered by datasets, and they are ready to go.
dataset_triggered_dag_info = {
x.dag_id: (x.first_queued_time, x.last_queued_time)
- for x in session.query(
- DagScheduleDatasetReference.dag_id,
- func.max(DDRQ.created_at).label("last_queued_time"),
- func.min(DDRQ.created_at).label("first_queued_time"),
+ for x in session.execute(
+ select(
+ DagScheduleDatasetReference.dag_id,
+ func.max(DDRQ.created_at).label("last_queued_time"),
+ func.min(DDRQ.created_at).label("first_queued_time"),
+ )
+ .join(DagScheduleDatasetReference.queue_records, isouter=True)
+ .group_by(DagScheduleDatasetReference.dag_id)
+                .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
            )
-            .join(DagScheduleDatasetReference.queue_records, isouter=True)
-            .group_by(DagScheduleDatasetReference.dag_id)
-            .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
- .all()
}
dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
if dataset_triggered_dag_ids:
exclusion_list = {
- x.dag_id
+ x
for x in (
- session.query(DagModel.dag_id)
- .join(DagRun.dag_model)
-                    .filter(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
- .filter(DagModel.dag_id.in_(dataset_triggered_dag_ids))
- .group_by(DagModel.dag_id)
- .having(func.count() >= func.max(DagModel.max_active_runs))
- .all()
+ session.scalars(
+ select(DagModel.dag_id)
+ .join(DagRun.dag_model)
+                        .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
+                        .where(DagModel.dag_id.in_(dataset_triggered_dag_ids))
+                        .group_by(DagModel.dag_id)
+                        .having(func.count() >= func.max(DagModel.max_active_runs))
+ )
)
}
if exclusion_list:
@@ -3588,8 +3609,8 @@ class DagModel(Base):
        # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs
query = (
- session.query(cls)
- .filter(
+ select(cls)
+ .where(
cls.is_paused == expression.false(),
cls.is_active == expression.true(),
cls.has_import_errors == expression.false(),
@@ -3603,7 +3624,7 @@ class DagModel(Base):
)
return (
-            with_row_locks(query, of=cls, session=session, **skip_locked(session=session)),
+            session.scalars(with_row_locks(query, of=cls, session=session, **skip_locked(session=session))),
dataset_triggered_dag_info,
)
@@ -3871,10 +3892,8 @@ def _get_or_create_dagrun(
:return: The newly created DAG run.
"""
log.info("dagrun id: %s", dag.dag_id)
- dr: DagRun = (
- session.query(DagRun)
-        .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
-        .first()
+    dr: DagRun = session.scalar(
+        select(DagRun).where(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
)
if dr:
session.delete(dr)
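One idiom recurs across dag.py: select() has no .first(), so the diff replaces
Query.first() with session.scalar(), adding .limit(1) where the statement
could match several rows. session.scalar() returns the first column of the
first row, or None. A sketch:

    from sqlalchemy import select

    last_run = session.scalar(
        select(DagRun)
        .where(DagRun.dag_id == dag_id)
        .order_by(DagRun.execution_date.desc())
        .limit(1)  # cap the scan; scalar() takes the first row or None
    )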
diff --git a/airflow/models/dagcode.py b/airflow/models/dagcode.py
index 9849f328f2..61b007bf6c 100644
--- a/airflow/models/dagcode.py
+++ b/airflow/models/dagcode.py
@@ -22,7 +22,7 @@ import struct
from datetime import datetime
from typing import Collection, Iterable
-from sqlalchemy import BigInteger, Column, String, Text, delete
+from sqlalchemy import BigInteger, Column, String, Text, delete, select
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import literal
@@ -77,12 +77,11 @@ class DagCode(Base):
"""
filelocs = set(filelocs)
        filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for fileloc in filelocs}
- existing_orm_dag_codes = (
- session.query(DagCode)
+ existing_orm_dag_codes = session.scalars(
+ select(DagCode)
.filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
.with_for_update(of=DagCode)
- .all()
- )
+ ).all()
if existing_orm_dag_codes:
existing_orm_dag_codes_map = {
@@ -151,7 +150,10 @@ class DagCode(Base):
:param session: ORM Session
"""
fileloc_hash = cls.dag_fileloc_hash(fileloc)
-        return session.query(literal(True)).filter(cls.fileloc_hash == fileloc_hash).one_or_none() is not None
+        return (
+            session.scalars(select(literal(True)).where(cls.fileloc_hash == fileloc_hash)).one_or_none()
+ is not None
+ )
@classmethod
def get_code_by_fileloc(cls, fileloc: str) -> str:
@@ -179,7 +181,7 @@ class DagCode(Base):
@classmethod
@provide_session
def _get_code_from_db(cls, fileloc, session: Session = NEW_SESSION) -> str:
-        dag_code = session.query(cls).filter(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)).first()
+        dag_code = session.scalar(select(cls).where(cls.fileloc_hash == cls.dag_fileloc_hash(fileloc)))
if not dag_code:
raise DagCodeNotFound()
else:
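The has_dag() hunk keeps the cheap existence-check idiom in the new style:
select a constant instead of the row and test whether anything came back. A
sketch, assuming at most one row can match (fileloc_hash is effectively
unique here, so one_or_none() cannot raise):

    from sqlalchemy import select
    from sqlalchemy.sql.expression import literal

    exists = (
        session.scalars(
            select(literal(True)).where(DagCode.fileloc_hash == fileloc_hash)
        ).one_or_none()
        is not None
    )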
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 538cceede8..dab89c12cd 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -40,6 +40,7 @@ from sqlalchemy import (
func,
or_,
text,
+ update,
)
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.associationproxy import association_proxy
@@ -270,7 +271,9 @@ class DagRun(Base, LoggingMixin):
:param session: database session
"""
-        dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id).one()
+        dr = session.scalars(
+            select(DagRun).where(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id)
+ ).one()
self.id = dr.id
self.state = dr.state
@@ -283,17 +286,17 @@ class DagRun(Base, LoggingMixin):
session: Session = NEW_SESSION,
) -> dict[str, int]:
"""Get the number of active dag runs for each dag."""
- query = session.query(cls.dag_id, func.count("*"))
+ query = select(cls.dag_id, func.count("*"))
if dag_ids is not None:
            # 'set' called to avoid duplicate dag_ids, but converted back to 'list'
# because SQLAlchemy doesn't accept a set here.
- query = query.filter(cls.dag_id.in_(set(dag_ids)))
+ query = query.where(cls.dag_id.in_(set(dag_ids)))
if only_running:
- query = query.filter(cls.state == State.RUNNING)
+ query = query.where(cls.state == State.RUNNING)
else:
- query = query.filter(cls.state.in_([State.RUNNING, State.QUEUED]))
+ query = query.where(cls.state.in_([State.RUNNING, State.QUEUED]))
query = query.group_by(cls.dag_id)
- return {dag_id: count for dag_id, count in query.all()}
+ return {dag_id: count for dag_id, count in session.execute(query)}
@classmethod
def next_dagruns_to_examine(
@@ -317,22 +320,22 @@ class DagRun(Base, LoggingMixin):
# TODO: Bake this query, it is run _A lot_
query = (
- session.query(cls)
+ select(cls)
            .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql")
-            .filter(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
+            .where(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
            .join(DagModel, DagModel.dag_id == cls.dag_id)
-            .filter(DagModel.is_paused == false(), DagModel.is_active == true())
+ .where(DagModel.is_paused == false(), DagModel.is_active == true())
)
if state == State.QUEUED:
            # For dag runs in the queued state, we check if they have reached the max_active_runs limit
# and if so we drop them
running_drs = (
-                session.query(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
-                .filter(DagRun.state == DagRunState.RUNNING)
+                select(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
+ .where(DagRun.state == DagRunState.RUNNING)
.group_by(DagRun.dag_id)
.subquery()
)
-            query = query.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id).filter(
+            query = query.outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id).where(
                func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs
)
query = query.order_by(
@@ -341,10 +344,10 @@ class DagRun(Base, LoggingMixin):
)
if not settings.ALLOW_FUTURE_EXEC_DATES:
- query = query.filter(DagRun.execution_date <= func.now())
+ query = query.where(DagRun.execution_date <= func.now())
- return with_row_locks(
-            query.limit(max_number), of=cls, session=session, **skip_locked(session=session)
+        return session.scalars(
+            with_row_locks(query.limit(max_number), of=cls, session=session, **skip_locked(session=session))
)
@classmethod
@@ -377,35 +380,35 @@ class DagRun(Base, LoggingMixin):
:param execution_start_date: dag run that was executed from this date
:param execution_end_date: dag run that was executed until this date
"""
- qry = session.query(cls)
+ qry = select(cls)
dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
if dag_ids:
- qry = qry.filter(cls.dag_id.in_(dag_ids))
+ qry = qry.where(cls.dag_id.in_(dag_ids))
if is_container(run_id):
- qry = qry.filter(cls.run_id.in_(run_id))
+ qry = qry.where(cls.run_id.in_(run_id))
elif run_id is not None:
- qry = qry.filter(cls.run_id == run_id)
+ qry = qry.where(cls.run_id == run_id)
if is_container(execution_date):
- qry = qry.filter(cls.execution_date.in_(execution_date))
+ qry = qry.where(cls.execution_date.in_(execution_date))
elif execution_date is not None:
- qry = qry.filter(cls.execution_date == execution_date)
+ qry = qry.where(cls.execution_date == execution_date)
if execution_start_date and execution_end_date:
-            qry = qry.filter(cls.execution_date.between(execution_start_date, execution_end_date))
+            qry = qry.where(cls.execution_date.between(execution_start_date, execution_end_date))
elif execution_start_date:
- qry = qry.filter(cls.execution_date >= execution_start_date)
+ qry = qry.where(cls.execution_date >= execution_start_date)
elif execution_end_date:
- qry = qry.filter(cls.execution_date <= execution_end_date)
+ qry = qry.where(cls.execution_date <= execution_end_date)
if state:
- qry = qry.filter(cls.state == state)
+ qry = qry.where(cls.state == state)
if external_trigger is not None:
- qry = qry.filter(cls.external_trigger == external_trigger)
+ qry = qry.where(cls.external_trigger == external_trigger)
if run_type:
- qry = qry.filter(cls.run_type == run_type)
+ qry = qry.where(cls.run_type == run_type)
if no_backfills:
- qry = qry.filter(cls.run_type != DagRunType.BACKFILL_JOB)
+ qry = qry.where(cls.run_type != DagRunType.BACKFILL_JOB)
- return qry.order_by(cls.execution_date).all()
+ return session.scalars(qry.order_by(cls.execution_date)).all()
@classmethod
@provide_session
@@ -426,14 +429,12 @@ class DagRun(Base, LoggingMixin):
:param execution_date: the execution date
:param session: database session
"""
- return (
- session.query(cls)
- .filter(
+ return session.scalars(
+ select(cls).where(
cls.dag_id == dag_id,
                or_(cls.run_id == run_id, cls.execution_date == execution_date),
)
- .one_or_none()
- )
+ ).one_or_none()
@staticmethod
def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
@@ -449,9 +450,9 @@ class DagRun(Base, LoggingMixin):
) -> list[TI]:
"""Returns the task instances for this dag run."""
tis = (
- session.query(TI)
+ select(TI)
.options(joinedload(TI.dag_run))
- .filter(
+ .where(
TI.dag_id == self.dag_id,
TI.run_id == self.run_id,
)
@@ -459,21 +460,21 @@ class DagRun(Base, LoggingMixin):
if state:
if isinstance(state, str):
- tis = tis.filter(TI.state == state)
+ tis = tis.where(TI.state == state)
else:
# this is required to deal with NULL values
if State.NONE in state:
if all(x is None for x in state):
- tis = tis.filter(TI.state.is_(None))
+ tis = tis.where(TI.state.is_(None))
else:
not_none_state = [s for s in state if s]
-                    tis = tis.filter(or_(TI.state.in_(not_none_state), TI.state.is_(None)))
+                    tis = tis.where(or_(TI.state.in_(not_none_state), TI.state.is_(None)))
else:
- tis = tis.filter(TI.state.in_(state))
+ tis = tis.where(TI.state.in_(state))
if self.dag and self.dag.partial:
- tis = tis.filter(TI.task_id.in_(self.dag.task_ids))
- return tis.all()
+ tis = tis.where(TI.task_id.in_(self.dag.task_ids))
+ return session.scalars(tis).all()
@provide_session
def get_task_instance(
@@ -489,11 +490,9 @@ class DagRun(Base, LoggingMixin):
:param task_id: the task id
:param session: Sqlalchemy ORM Session
"""
- return (
- session.query(TI)
-            .filter_by(dag_id=self.dag_id, run_id=self.run_id, task_id=task_id, map_index=map_index)
-            .one_or_none()
-        )
+        return session.scalars(
+            select(TI).filter_by(dag_id=self.dag_id, run_id=self.run_id, task_id=task_id, map_index=map_index)
+ ).one_or_none()
def get_dag(self) -> DAG:
"""
@@ -517,20 +516,19 @@ class DagRun(Base, LoggingMixin):
]
if state is not None:
filters.append(DagRun.state == state)
-        return session.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first()
+        return session.scalar(select(DagRun).where(*filters).order_by(DagRun.execution_date.desc()))
@provide_session
    def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> DagRun | None:
"""The previous, SCHEDULED DagRun, if there is one."""
- return (
- session.query(DagRun)
- .filter(
+ return session.scalar(
+ select(DagRun)
+ .where(
DagRun.dag_id == self.dag_id,
DagRun.execution_date < self.execution_date,
DagRun.run_type != DagRunType.MANUAL,
)
.order_by(DagRun.execution_date.desc())
- .first()
)
def _tis_for_dagrun_state(self, *, dag, tis):
@@ -867,7 +865,7 @@ class DagRun(Base, LoggingMixin):
# Check if any ti changed state
tis_filter = TI.filter_for_tis(old_states)
if tis_filter is not None:
- fresh_tis = session.query(TI).filter(tis_filter).all()
+ fresh_tis = session.scalars(select(TI).where(tis_filter)).all()
            changed_tis = any(ti.state != old_states[ti.key] for ti in fresh_tis)
return ready_tis, changed_tis, expansion_happened
@@ -1219,21 +1217,27 @@ class DagRun(Base, LoggingMixin):
except NotFullyPopulated:
return # Upstreams not ready, don't need to revise this yet.
- query = session.query(TI.map_index).filter(
- TI.dag_id == self.dag_id,
- TI.task_id == task.task_id,
- TI.run_id == self.run_id,
+ query = session.scalars(
+ select(TI.map_index).where(
+ TI.dag_id == self.dag_id,
+ TI.task_id == task.task_id,
+ TI.run_id == self.run_id,
+ )
)
- existing_indexes = {i for (i,) in query}
+ existing_indexes = {i for i in query}
removed_indexes = existing_indexes.difference(range(total_length))
if removed_indexes:
- session.query(TI).filter(
- TI.dag_id == self.dag_id,
- TI.task_id == task.task_id,
- TI.run_id == self.run_id,
- TI.map_index.in_(removed_indexes),
- ).update({TI.state: TaskInstanceState.REMOVED})
+ session.execute(
+ update(TI)
+ .where(
+ TI.dag_id == self.dag_id,
+ TI.task_id == task.task_id,
+ TI.run_id == self.run_id,
+ TI.map_index.in_(removed_indexes),
+ )
+ .values(state=TaskInstanceState.REMOVED)
+ )
session.flush()
for index in range(total_length):
@@ -1264,14 +1268,12 @@ class DagRun(Base, LoggingMixin):
RemovedInAirflow3Warning,
stacklevel=2,
)
- return (
- session.query(DagRun)
- .filter(
+ return session.scalar(
+ select(DagRun).where(
DagRun.dag_id == dag_id,
DagRun.external_trigger == False, # noqa
DagRun.execution_date == execution_date,
)
- .first()
)
@property
@@ -1283,18 +1285,16 @@ class DagRun(Base, LoggingMixin):
def get_latest_runs(cls, session: Session = NEW_SESSION) -> list[DagRun]:
"""Returns the latest DagRun for each DAG."""
subquery = (
-            session.query(cls.dag_id, func.max(cls.execution_date).label("execution_date"))
+            select(cls.dag_id, func.max(cls.execution_date).label("execution_date"))
.group_by(cls.dag_id)
.subquery()
)
- return (
- session.query(cls)
- .join(
+ return session.scalars(
+ select(cls).join(
subquery,
                and_(cls.dag_id == subquery.c.dag_id, cls.execution_date == subquery.c.execution_date),
)
- .all()
- )
+ ).all()
@provide_session
def schedule_tis(
@@ -1335,44 +1335,45 @@ class DagRun(Base, LoggingMixin):
                schedulable_ti_ids, max_tis_per_query or len(schedulable_ti_ids)
)
for schedulable_ti_ids_chunk in schedulable_ti_ids_chunks:
- count += (
- session.query(TI)
- .filter(
+ count += session.execute(
+ update(TI)
+ .where(
TI.dag_id == self.dag_id,
TI.run_id == self.run_id,
                        tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids_chunk),
)
-                    .update({TI.state: State.SCHEDULED}, synchronize_session=False)
- )
+ .values(state=TaskInstanceState.SCHEDULED)
+ .execution_options(synchronize_session=False)
+ ).rowcount
        # Tasks using EmptyOperator should not be executed, mark them as success
if dummy_ti_ids:
            dummy_ti_ids_chunks = chunks(dummy_ti_ids, max_tis_per_query or len(dummy_ti_ids))
for dummy_ti_ids_chunk in dummy_ti_ids_chunks:
- count += (
- session.query(TI)
- .filter(
+ count += session.execute(
+ update(TI)
+ .where(
TI.dag_id == self.dag_id,
TI.run_id == self.run_id,
TI.task_id.in_(dummy_ti_ids_chunk),
)
- .update(
- {
- TI.state: State.SUCCESS,
- TI.start_date: timezone.utcnow(),
- TI.end_date: timezone.utcnow(),
- TI.duration: 0,
- },
+ .values(
+ state=TaskInstanceState.SUCCESS,
+ start_date=timezone.utcnow(),
+ end_date=timezone.utcnow(),
+ duration=0,
+ )
+ .execution_options(
synchronize_session=False,
)
- )
+ ).rowcount
return count
@provide_session
    def get_log_template(self, *, session: Session = NEW_SESSION) -> LogTemplate:
        if self.log_template_id is None:  # DagRun created before LogTemplate introduction.
-            template = session.query(LogTemplate).order_by(LogTemplate.id).first()
+            template = session.scalar(select(LogTemplate).order_by(LogTemplate.id))
else:
template = session.get(LogTemplate, self.log_template_id)
if template is None:
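schedule_tis() above also shows how the old Query.update() return value
carries over: session.execute() on an update() returns a result whose
.rowcount is the number of rows the UPDATE matched. A sketch:

    from sqlalchemy import update

    result = session.execute(
        update(TaskInstance)
        .where(TaskInstance.run_id == run_id, TaskInstance.task_id.in_(task_ids))
        .values(state=TaskInstanceState.SCHEDULED)
        .execution_options(synchronize_session=False)
    )
    count += result.rowcount  # rows matched by the UPDATE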
diff --git a/airflow/models/dagwarning.py b/airflow/models/dagwarning.py
index ddc5441321..2690ba8ff8 100644
--- a/airflow/models/dagwarning.py
+++ b/airflow/models/dagwarning.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from enum import Enum
-from sqlalchemy import Column, ForeignKeyConstraint, String, Text, false
+from sqlalchemy import Column, ForeignKeyConstraint, String, Text, delete, false, select
from sqlalchemy.orm import Session
from airflow.api_internal.internal_api_call import internal_api_call
@@ -83,11 +83,12 @@ class DagWarning(Base):
from airflow.models.dag import DagModel
if session.get_bind().dialect.name == "sqlite":
-            dag_ids = session.query(DagModel.dag_id).filter(DagModel.is_active == false())
-            query = session.query(cls).filter(cls.dag_id.in_(dag_ids))
+            dag_ids_stmt = select(DagModel.dag_id).where(DagModel.is_active == false())
+            query = delete(cls).where(cls.dag_id.in_(dag_ids_stmt.scalar_subquery()))
        else:
-            query = session.query(cls).filter(cls.dag_id == DagModel.dag_id, DagModel.is_active == false())
-            query.delete(synchronize_session=False)
+            query = delete(cls).where(cls.dag_id == DagModel.dag_id, DagModel.is_active == false())
+
+ session.execute(query.execution_options(synchronize_session=False))
session.commit()
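The SQLite branch above is worth spelling out: SQLite cannot reference a
second table in a DELETE, so the active-DAG criterion becomes an IN against a
scalar subquery, while other backends keep the correlated two-table form. A
sketch of the subquery variant:

    from sqlalchemy import delete, false, select

    inactive_dags = select(DagModel.dag_id).where(DagModel.is_active == false())
    stmt = delete(DagWarning).where(
        DagWarning.dag_id.in_(inactive_dags.scalar_subquery())
    )
    session.execute(stmt.execution_options(synchronize_session=False))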
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index d1766d4a0a..60f92506f6 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -17,9 +17,9 @@
# under the License.
from __future__ import annotations
-from typing import Any, Iterable
+from typing import Any
-from sqlalchemy import Column, Integer, String, Text, func
+from sqlalchemy import Column, Integer, String, Text, func, select
from sqlalchemy.orm.session import Session
from airflow.exceptions import AirflowException, PoolNotFound
@@ -60,7 +60,7 @@ class Pool(Base):
@provide_session
def get_pools(session: Session = NEW_SESSION) -> list[Pool]:
"""Get all pools."""
- return session.query(Pool).all()
+ return session.scalars(select(Pool)).all()
@staticmethod
@provide_session
@@ -72,7 +72,7 @@ class Pool(Base):
:param session: SQLAlchemy ORM Session
:return: the pool object
"""
- return session.query(Pool).filter(Pool.pool == pool_name).first()
+ return session.scalar(select(Pool).where(Pool.pool == pool_name))
@staticmethod
@provide_session
@@ -96,9 +96,9 @@ class Pool(Base):
:return: True if id is default_pool, otherwise False
"""
return (
- session.query(func.count(Pool.id))
- .filter(Pool.id == id, Pool.pool == Pool.DEFAULT_POOL_NAME)
- .scalar()
+ session.scalar(
+                select(func.count(Pool.id)).where(Pool.id == id, Pool.pool == Pool.DEFAULT_POOL_NAME)
+ )
> 0
)
@@ -114,7 +114,7 @@ class Pool(Base):
if not name:
raise ValueError("Pool name must not be empty")
- pool = session.query(Pool).filter_by(pool=name).one_or_none()
+ pool = session.scalar(select(Pool).filter_by(pool=name))
if pool is None:
pool = Pool(pool=name, slots=slots, description=description)
session.add(pool)
@@ -132,7 +132,7 @@ class Pool(Base):
if name == Pool.DEFAULT_POOL_NAME:
            raise AirflowException(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")
- pool = session.query(Pool).filter_by(pool=name).first()
+ pool = session.scalar(select(Pool).filter_by(pool=name))
if pool is None:
raise PoolNotFound(f"Pool '{name}' doesn't exist")
@@ -162,22 +162,22 @@ class Pool(Base):
pools: dict[str, PoolStats] = {}
- query = session.query(Pool.pool, Pool.slots)
+ query = select(Pool.pool, Pool.slots)
if lock_rows:
query = with_row_locks(query, session=session, **nowait(session))
- pool_rows: Iterable[tuple[str, int]] = query.all()
+ pool_rows = session.execute(query)
for (pool_name, total_slots) in pool_rows:
if total_slots == -1:
total_slots = float("inf") # type: ignore
            pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0)
- state_count_by_pool = (
-            session.query(TaskInstance.pool, TaskInstance.state, func.sum(TaskInstance.pool_slots))
+        state_count_by_pool = session.execute(
+            select(TaskInstance.pool, TaskInstance.state, func.sum(TaskInstance.pool_slots))
.filter(TaskInstance.state.in_(list(EXECUTION_STATES)))
.group_by(TaskInstance.pool, TaskInstance.state)
- ).all()
+ )
# calculate queued and running metrics
for (pool_name, state, count) in state_count_by_pool:
@@ -225,10 +225,11 @@ class Pool(Base):
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import
return int(
- session.query(func.sum(TaskInstance.pool_slots))
- .filter(TaskInstance.pool == self.pool)
- .filter(TaskInstance.state.in_(EXECUTION_STATES))
- .scalar()
+ session.scalar(
+ select(func.sum(TaskInstance.pool_slots))
+ .filter(TaskInstance.pool == self.pool)
+ .filter(TaskInstance.state.in_(EXECUTION_STATES))
+ )
or 0
)
@@ -243,10 +244,11 @@ class Pool(Base):
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import
return int(
- session.query(func.sum(TaskInstance.pool_slots))
- .filter(TaskInstance.pool == self.pool)
- .filter(TaskInstance.state == State.RUNNING)
- .scalar()
+ session.scalar(
+ select(func.sum(TaskInstance.pool_slots))
+ .filter(TaskInstance.pool == self.pool)
+ .filter(TaskInstance.state == State.RUNNING)
+ )
or 0
)
@@ -261,10 +263,11 @@ class Pool(Base):
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import
return int(
- session.query(func.sum(TaskInstance.pool_slots))
- .filter(TaskInstance.pool == self.pool)
- .filter(TaskInstance.state == State.QUEUED)
- .scalar()
+ session.scalar(
+ select(func.sum(TaskInstance.pool_slots))
+ .filter(TaskInstance.pool == self.pool)
+ .filter(TaskInstance.state == State.QUEUED)
+ )
or 0
)
@@ -279,10 +282,11 @@ class Pool(Base):
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import
return int(
- session.query(func.sum(TaskInstance.pool_slots))
- .filter(TaskInstance.pool == self.pool)
- .filter(TaskInstance.state == State.SCHEDULED)
- .scalar()
+ session.scalar(
+ select(func.sum(TaskInstance.pool_slots))
+ .filter(TaskInstance.pool == self.pool)
+ .filter(TaskInstance.state == State.SCHEDULED)
+ )
or 0
)
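slots_stats() above illustrates the row-iteration side of the API: the result
of session.execute() is iterable as-is, so the interim .all() the 1.x code
needed can simply be dropped unless the rows must be materialized up front. A
sketch (EXECUTION_STATES as in the module; model and session assumed):

    from sqlalchemy import func, select

    rows = session.execute(
        select(TaskInstance.pool, TaskInstance.state, func.sum(TaskInstance.pool_slots))
        .where(TaskInstance.state.in_(list(EXECUTION_STATES)))
        .group_by(TaskInstance.pool, TaskInstance.state)
    )
    for pool_name, state, count in rows:  # Row objects unpack like tuples
        ...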
diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py
index 4903a88e56..edba1ffd9c 100644
--- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py
@@ -385,6 +385,7 @@ class TestGetMappedTaskInstances(TestMappedTaskInstanceEndpoint):
)
assert response.status_code == 200
assert response.json["total_entries"] == 0
+ assert response.json["task_instances"] == []
@provide_session
    def test_mapped_task_instances_with_state(self, one_task_with_mapped_tis, session):
@@ -402,6 +403,7 @@ class TestGetMappedTaskInstances(TestMappedTaskInstanceEndpoint):
)
assert response.status_code == 200
assert response.json["total_entries"] == 0
+ assert response.json["task_instances"] == []
@provide_session
    def test_mapped_task_instances_with_pool(self, one_task_with_mapped_tis, session):
@@ -420,6 +422,7 @@ class TestGetMappedTaskInstances(TestMappedTaskInstanceEndpoint):
)
assert response.status_code == 200
assert response.json["total_entries"] == 0
+ assert response.json["task_instances"] == []
@provide_session
    def test_mapped_task_instances_with_queue(self, one_task_with_mapped_tis, session):
@@ -437,6 +440,7 @@ class TestGetMappedTaskInstances(TestMappedTaskInstanceEndpoint):
)
assert response.status_code == 200
assert response.json["total_entries"] == 0
+ assert response.json["task_instances"] == []
@provide_session
    def test_mapped_task_instances_with_zero_mapped(self, one_task_with_zero_mapped_tis, session):
@@ -446,4 +450,4 @@ class TestGetMappedTaskInstances(TestMappedTaskInstanceEndpoint):
)
assert response.status_code == 200
assert response.json["total_entries"] == 0
- assert len(response.json["task_instances"]) == 0
+ assert response.json["task_instances"] == []
diff --git a/tests/models/test_dagwarning.py b/tests/models/test_dagwarning.py
index 7ae2962b1f..06b14b56ea 100644
--- a/tests/models/test_dagwarning.py
+++ b/tests/models/test_dagwarning.py
@@ -17,6 +17,7 @@
from __future__ import annotations
+from unittest import mock
from unittest.mock import MagicMock
from sqlalchemy.exc import OperationalError
@@ -52,18 +53,18 @@ class TestDagWarning:
assert len(remaining_dag_warnings) == 1
assert remaining_dag_warnings[0].dag_id == "dag_2"
- def test_retry_purge_inactive_dag_warnings(self):
+ @mock.patch("airflow.models.dagwarning.delete")
+ def test_retry_purge_inactive_dag_warnings(self, delete_mock):
"""
        Test that the purge_inactive_dag_warnings method calls the delete method twice
        if the query throws an OperationalError on the first call and works on the second attempt
"""
self.session_mock = MagicMock()
- self.delete_mock = MagicMock()
-        self.session_mock.query.return_value.filter.return_value.delete = self.delete_mock
-        self.delete_mock.side_effect = [OperationalError(None, None, "database timeout"), None]
+        self.session_mock.execute.side_effect = [OperationalError(None, None, "database timeout"), None]
DagWarning.purge_inactive_dag_warnings(self.session_mock)
# Assert that the delete method was called twice
- assert self.delete_mock.call_count == 2
+ assert delete_mock.call_count == 2
+ assert self.session_mock.execute.call_count == 2
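Because purge_inactive_dag_warnings() now funnels everything through
session.execute(), the retry test can mock that single entry point instead of
the old query/filter/delete chain. A sketch of the same arrangement, mirroring
the test above (names illustrative):

    from unittest.mock import MagicMock

    from sqlalchemy.exc import OperationalError

    session_mock = MagicMock()
    # First execute() raises, second succeeds; the method should retry once.
    session_mock.execute.side_effect = [OperationalError(None, None, "database timeout"), None]

    DagWarning.purge_inactive_dag_warnings(session_mock)
    assert session_mock.execute.call_count == 2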