This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 83f5950976 Refactor Sqlalchemy queries to 2.0 styles (Part 2) (#31772)
83f5950976 is described below

commit 83f595097606bbe48d716b1c94eace7bd6ac4e77
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Tue Jun 27 15:09:05 2023 +0100

    Refactor Sqlalchemy queries to 2.0 styles (Part 2) (#31772)
    
    * Refactor Sqlalchemy queries to 2.0 styles (Part 2)
    
    This is a continuation of the effort to refactor the queries to the SQLAlchemy 2.0 style.
    
    * Fix total_entries counting logic
    
    The count should happen *after* we apply the filters, and is (slightly)
    different from the automatic 404 case.
    
    * Update queries in models/
    
    * fixup! Update queries in models/
    
    * Add triggerer_job in query
    
    * fixup! fixup! Update queries in models/
    
    * Revert changes to dagwarning.py
    
    * Use session.scalar
    
    * Convert DagWarning to SQLAlchemy 2.x syntax
    
    * Fix dagwarning
    
    * Fix typing and remove TODO comment
    
    * Apply suggestions from code review
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    
    * Use scalars instead of execute and fix typing
    
    * Apply suggestions from code review
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    
    * Apply suggestions from code review
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    
    * remove all
    
    * Remove one_or_none where possible
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 .../api_connexion/endpoints/connection_endpoint.py |  17 +-
 airflow/api_connexion/endpoints/dag_endpoint.py    |  44 +--
 .../api_connexion/endpoints/dag_run_endpoint.py    |  78 ++---
 .../endpoints/dag_warning_endpoint.py              |  11 +-
 .../api_connexion/endpoints/dataset_endpoint.py    |  36 ++-
 .../api_connexion/endpoints/event_log_endpoint.py  |   8 +-
 .../api_connexion/endpoints/extra_link_endpoint.py |   7 +-
 .../endpoints/import_error_endpoint.py             |   8 +-
 airflow/api_connexion/endpoints/log_endpoint.py    |  12 +-
 airflow/api_connexion/endpoints/pool_endpoint.py   |  12 +-
 .../endpoints/role_and_permission_endpoint.py      |  18 +-
 .../endpoints/task_instance_endpoint.py            | 141 +++++----
 airflow/api_connexion/endpoints/user_endpoint.py   |   8 +-
 .../api_connexion/endpoints/variable_endpoint.py   |  16 +-
 airflow/api_connexion/endpoints/xcom_endpoint.py   |  30 +-
 airflow/api_connexion/parameters.py                |   6 +-
 airflow/models/abstractoperator.py                 |  35 ++-
 airflow/models/baseoperator.py                     |  34 +--
 airflow/models/dag.py                              | 333 +++++++++++----------
 airflow/models/dagcode.py                          |  16 +-
 airflow/models/dagrun.py                           | 183 +++++------
 airflow/models/dagwarning.py                       |  11 +-
 airflow/models/pool.py                             |  64 ++--
 .../test_mapped_task_instance_endpoint.py          |   6 +-
 tests/models/test_dagwarning.py                    |  11 +-
 25 files changed, 597 insertions(+), 548 deletions(-)

diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py 
b/airflow/api_connexion/endpoints/connection_endpoint.py
index 641c3288d3..737ee54d6c 100644
--- a/airflow/api_connexion/endpoints/connection_endpoint.py
+++ b/airflow/api_connexion/endpoints/connection_endpoint.py
@@ -22,7 +22,7 @@ from http import HTTPStatus
 from connexion import NoContent
 from flask import request
 from marshmallow import ValidationError
-from sqlalchemy import func
+from sqlalchemy import func, select
 from sqlalchemy.orm import Session
 
 from airflow.api_connexion import security
@@ -58,7 +58,7 @@ RESOURCE_EVENT_PREFIX = "connection"
 )
 def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) 
-> APIResponse:
     """Delete a connection entry."""
-    connection = 
session.query(Connection).filter_by(conn_id=connection_id).one_or_none()
+    connection = 
session.scalar(select(Connection).filter_by(conn_id=connection_id))
     if connection is None:
         raise NotFound(
             "Connection not found",
@@ -72,7 +72,7 @@ def delete_connection(*, connection_id: str, session: Session 
= NEW_SESSION) ->
 @provide_session
 def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> 
APIResponse:
     """Get a connection entry."""
-    connection = session.query(Connection).filter(Connection.conn_id == 
connection_id).one_or_none()
+    connection = session.scalar(select(Connection).where(Connection.conn_id == 
connection_id))
     if connection is None:
         raise NotFound(
             "Connection not found",
@@ -95,10 +95,10 @@ def get_connections(
     to_replace = {"connection_id": "conn_id"}
     allowed_filter_attrs = ["connection_id", "conn_type", "description", 
"host", "port", "id"]
 
-    total_entries = session.query(func.count(Connection.id)).scalar()
-    query = session.query(Connection)
+    total_entries = 
session.execute(select(func.count(Connection.id))).scalar_one()
+    query = select(Connection)
     query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
-    connections = query.offset(offset).limit(limit).all()
+    connections = session.scalars(query.offset(offset).limit(limit)).all()
     return connection_collection_schema.dump(
         ConnectionCollection(connections=connections, 
total_entries=total_entries)
     )
@@ -125,7 +125,7 @@ def patch_connection(
         # If validation get to here, it is extra field validation.
         raise BadRequest(detail=str(err.messages))
     non_update_fields = ["connection_id", "conn_id"]
-    connection = 
session.query(Connection).filter_by(conn_id=connection_id).first()
+    connection = 
session.scalar(select(Connection).filter_by(conn_id=connection_id).limit(1))
     if connection is None:
         raise NotFound(
             "Connection not found",
@@ -162,8 +162,7 @@ def post_connection(*, session: Session = NEW_SESSION) -> 
APIResponse:
         helpers.validate_key(conn_id, max_length=200)
     except Exception as e:
         raise BadRequest(detail=str(e))
-    query = session.query(Connection)
-    connection = query.filter_by(conn_id=conn_id).first()
+    connection = 
session.scalar(select(Connection).filter_by(conn_id=conn_id).limit(1))
     if not connection:
         connection = Connection(**data)
         session.add(connection)
diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py 
b/airflow/api_connexion/endpoints/dag_endpoint.py
index 18dcb45be6..55eb971c08 100644
--- a/airflow/api_connexion/endpoints/dag_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_endpoint.py
@@ -22,6 +22,7 @@ from typing import Collection
 from connexion import NoContent
 from flask import g, request
 from marshmallow import ValidationError
+from sqlalchemy import func, select, update
 from sqlalchemy.orm import Session
 from sqlalchemy.sql.expression import or_
 
@@ -47,7 +48,7 @@ from airflow.utils.session import NEW_SESSION, provide_session
 @provide_session
 def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
     """Get basic information about a DAG."""
-    dag = session.query(DagModel).filter(DagModel.dag_id == 
dag_id).one_or_none()
+    dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id))
 
     if dag is None:
         raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} 
was not found")
@@ -80,27 +81,27 @@ def get_dags(
 ) -> APIResponse:
     """Get all DAGs."""
     allowed_attrs = ["dag_id"]
-    dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
+    dags_query = select(DagModel).where(~DagModel.is_subdag)
     if only_active:
-        dags_query = dags_query.filter(DagModel.is_active)
+        dags_query = dags_query.where(DagModel.is_active)
     if paused is not None:
         if paused:
-            dags_query = dags_query.filter(DagModel.is_paused)
+            dags_query = dags_query.where(DagModel.is_paused)
         else:
-            dags_query = dags_query.filter(~DagModel.is_paused)
+            dags_query = dags_query.where(~DagModel.is_paused)
     if dag_id_pattern:
-        dags_query = 
dags_query.filter(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))
+        dags_query = 
dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))
 
     readable_dags = 
get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
 
-    dags_query = dags_query.filter(DagModel.dag_id.in_(readable_dags))
+    dags_query = dags_query.where(DagModel.dag_id.in_(readable_dags))
     if tags:
         cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
-        dags_query = dags_query.filter(or_(*cond))
+        dags_query = dags_query.where(or_(*cond))
 
-    total_entries = dags_query.count()
+    total_entries = 
session.scalar(select(func.count()).select_from(dags_query))
     dags_query = apply_sorting(dags_query, order_by, {}, allowed_attrs)
-    dags = dags_query.offset(offset).limit(limit).all()
+    dags = session.scalars(dags_query.offset(offset).limit(limit)).all()
 
     return dags_collection_schema.dump(DAGCollection(dags=dags, 
total_entries=total_entries))
 
@@ -119,7 +120,7 @@ def patch_dag(*, dag_id: str, update_mask: UpdateMask = 
None, session: Session =
             raise BadRequest(detail="Only `is_paused` field can be updated 
through the REST API")
         patch_body_[update_mask[0]] = patch_body[update_mask[0]]
         patch_body = patch_body_
-    dag = session.query(DagModel).filter(DagModel.dag_id == 
dag_id).one_or_none()
+    dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id))
     if not dag:
         raise NotFound(f"Dag with id: '{dag_id}' not found")
     dag.is_paused = patch_body["is_paused"]
@@ -144,27 +145,30 @@ def patch_dags(limit, session, offset=0, 
only_active=True, tags=None, dag_id_pat
         patch_body_[update_mask] = patch_body[update_mask]
         patch_body = patch_body_
     if only_active:
-        dags_query = session.query(DagModel).filter(~DagModel.is_subdag, 
DagModel.is_active)
+        dags_query = select(DagModel).where(~DagModel.is_subdag, 
DagModel.is_active)
     else:
-        dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
+        dags_query = select(DagModel).where(~DagModel.is_subdag)
 
     if dag_id_pattern == "~":
         dag_id_pattern = "%"
-    dags_query = 
dags_query.filter(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))
+    dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))
     editable_dags = 
get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user)
 
-    dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))
+    dags_query = dags_query.where(DagModel.dag_id.in_(editable_dags))
     if tags:
         cond = [DagModel.tags.any(DagTag.name == tag) for tag in tags]
-        dags_query = dags_query.filter(or_(*cond))
+        dags_query = dags_query.where(or_(*cond))
 
-    total_entries = dags_query.count()
+    total_entries = 
session.scalar(select(func.count()).select_from(dags_query))
 
-    dags = 
dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit).all()
+    dags = 
session.scalars(dags_query.order_by(DagModel.dag_id).offset(offset).limit(limit)).all()
 
     dags_to_update = {dag.dag_id for dag in dags}
-    session.query(DagModel).filter(DagModel.dag_id.in_(dags_to_update)).update(
-        {DagModel.is_paused: patch_body["is_paused"]}, 
synchronize_session="fetch"
+    session.execute(
+        update(DagModel)
+        .where(DagModel.dag_id.in_(dags_to_update))
+        .values(is_paused=patch_body["is_paused"])
+        .execution_options(synchronize_session="fetch")
     )
 
     session.flush()
diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py 
b/airflow/api_connexion/endpoints/dag_run_endpoint.py
index 95e5913ebd..f62b28273f 100644
--- a/airflow/api_connexion/endpoints/dag_run_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -23,8 +23,9 @@ from connexion import NoContent
 from flask import g
 from flask_login import current_user
 from marshmallow import ValidationError
-from sqlalchemy import delete, or_
-from sqlalchemy.orm import Query, Session
+from sqlalchemy import delete, func, or_, select
+from sqlalchemy.orm import Session
+from sqlalchemy.sql import Select
 
 from airflow.api.common.mark_tasks import (
     set_dag_run_state_to_failed,
@@ -91,7 +92,7 @@ def delete_dag_run(*, dag_id: str, dag_run_id: str, session: 
Session = NEW_SESSI
 @provide_session
 def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = 
NEW_SESSION) -> APIResponse:
     """Get a DAG Run."""
-    dag_run = session.query(DagRun).filter(DagRun.dag_id == dag_id, 
DagRun.run_id == dag_run_id).one_or_none()
+    dag_run = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, 
DagRun.run_id == dag_run_id))
     if dag_run is None:
         raise NotFound(
             "DAGRun not found",
@@ -112,13 +113,11 @@ def get_upstream_dataset_events(
     *, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION
 ) -> APIResponse:
     """If dag run is dataset-triggered, return the dataset events that 
triggered it."""
-    dag_run: DagRun | None = (
-        session.query(DagRun)
-        .filter(
+    dag_run: DagRun | None = session.scalar(
+        select(DagRun).where(
             DagRun.dag_id == dag_id,
             DagRun.run_id == dag_run_id,
         )
-        .one_or_none()
     )
     if dag_run is None:
         raise NotFound(
@@ -132,7 +131,7 @@ def get_upstream_dataset_events(
 
 
 def _fetch_dag_runs(
-    query: Query,
+    query: Select,
     *,
     end_date_gte: str | None,
     end_date_lte: str | None,
@@ -145,28 +144,29 @@ def _fetch_dag_runs(
     limit: int | None,
     offset: int | None,
     order_by: str,
+    session: Session,
 ) -> tuple[list[DagRun], int]:
     if start_date_gte:
-        query = query.filter(DagRun.start_date >= start_date_gte)
+        query = query.where(DagRun.start_date >= start_date_gte)
     if start_date_lte:
-        query = query.filter(DagRun.start_date <= start_date_lte)
+        query = query.where(DagRun.start_date <= start_date_lte)
     # filter execution date
     if execution_date_gte:
-        query = query.filter(DagRun.execution_date >= execution_date_gte)
+        query = query.where(DagRun.execution_date >= execution_date_gte)
     if execution_date_lte:
-        query = query.filter(DagRun.execution_date <= execution_date_lte)
+        query = query.where(DagRun.execution_date <= execution_date_lte)
     # filter end date
     if end_date_gte:
-        query = query.filter(DagRun.end_date >= end_date_gte)
+        query = query.where(DagRun.end_date >= end_date_gte)
     if end_date_lte:
-        query = query.filter(DagRun.end_date <= end_date_lte)
+        query = query.where(DagRun.end_date <= end_date_lte)
     # filter updated at
     if updated_at_gte:
-        query = query.filter(DagRun.updated_at >= updated_at_gte)
+        query = query.where(DagRun.updated_at >= updated_at_gte)
     if updated_at_lte:
-        query = query.filter(DagRun.updated_at <= updated_at_lte)
+        query = query.where(DagRun.updated_at <= updated_at_lte)
 
-    total_entries = query.count()
+    total_entries = session.scalar(select(func.count()).select_from(query))
     to_replace = {"dag_run_id": "run_id"}
     allowed_filter_attrs = [
         "id",
@@ -181,7 +181,7 @@ def _fetch_dag_runs(
         "conf",
     ]
     query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
-    return query.offset(offset).limit(limit).all(), total_entries
+    return session.scalars(query.offset(offset).limit(limit)).all(), 
total_entries
 
 
 @security.requires_access(
@@ -222,17 +222,17 @@ def get_dag_runs(
     session: Session = NEW_SESSION,
 ):
     """Get all DAG Runs."""
-    query = session.query(DagRun)
+    query = select(DagRun)
 
     #  This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs 
for all DAGs.
     if dag_id == "~":
         appbuilder = get_airflow_app().appbuilder
-        query = 
query.filter(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user)))
+        query = 
query.where(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user)))
     else:
-        query = query.filter(DagRun.dag_id == dag_id)
+        query = query.where(DagRun.dag_id == dag_id)
 
     if state:
-        query = query.filter(DagRun.state.in_(state))
+        query = query.where(DagRun.state.in_(state))
 
     dag_run, total_entries = _fetch_dag_runs(
         query,
@@ -247,6 +247,7 @@ def get_dag_runs(
         limit=limit,
         offset=offset,
         order_by=order_by,
+        session=session,
     )
     return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, 
total_entries=total_entries))
 
@@ -268,16 +269,16 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) 
-> APIResponse:
 
     appbuilder = get_airflow_app().appbuilder
     readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
-    query = session.query(DagRun)
+    query = select(DagRun)
     if data.get("dag_ids"):
         dag_ids = set(data["dag_ids"]) & set(readable_dag_ids)
-        query = query.filter(DagRun.dag_id.in_(dag_ids))
+        query = query.where(DagRun.dag_id.in_(dag_ids))
     else:
-        query = query.filter(DagRun.dag_id.in_(readable_dag_ids))
+        query = query.where(DagRun.dag_id.in_(readable_dag_ids))
 
     states = data.get("states")
     if states:
-        query = query.filter(DagRun.state.in_(states))
+        query = query.where(DagRun.state.in_(states))
 
     dag_runs, total_entries = _fetch_dag_runs(
         query,
@@ -290,6 +291,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) 
-> APIResponse:
         limit=data["page_limit"],
         offset=data["page_offset"],
         order_by=data.get("order_by", "id"),
+        session=session,
     )
 
     return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, 
total_entries=total_entries))
@@ -310,7 +312,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) 
-> APIResponse:
 )
 def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> 
APIResponse:
     """Trigger a DAG."""
-    dm = session.query(DagModel).filter(DagModel.is_active, DagModel.dag_id == 
dag_id).first()
+    dm = session.scalar(select(DagModel).where(DagModel.is_active, 
DagModel.dag_id == dag_id).limit(1))
     if not dm:
         raise NotFound(title="DAG not found", detail=f"DAG with dag_id: 
'{dag_id}' not found")
     if dm.has_import_errors:
@@ -325,13 +327,13 @@ def post_dag_run(*, dag_id: str, session: Session = 
NEW_SESSION) -> APIResponse:
 
     logical_date = pendulum.instance(post_body["execution_date"])
     run_id = post_body["run_id"]
-    dagrun_instance = (
-        session.query(DagRun)
-        .filter(
+    dagrun_instance = session.scalar(
+        select(DagRun)
+        .where(
             DagRun.dag_id == dag_id,
             or_(DagRun.run_id == run_id, DagRun.execution_date == 
logical_date),
         )
-        .first()
+        .limit(1)
     )
     if not dagrun_instance:
         try:
@@ -375,8 +377,8 @@ def post_dag_run(*, dag_id: str, session: Session = 
NEW_SESSION) -> APIResponse:
 @provide_session
 def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = 
NEW_SESSION) -> APIResponse:
     """Set a state of a dag run."""
-    dag_run: DagRun | None = (
-        session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == 
dag_run_id).one_or_none()
+    dag_run: DagRun | None = session.scalar(
+        select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == 
dag_run_id)
     )
     if dag_run is None:
         error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}"
@@ -407,8 +409,8 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, 
session: Session = NEW
 @provide_session
 def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = 
NEW_SESSION) -> APIResponse:
     """Clear a dag run."""
-    dag_run: DagRun | None = (
-        session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == 
dag_run_id).one_or_none()
+    dag_run: DagRun | None = session.scalar(
+        select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == 
dag_run_id)
     )
     if dag_run is None:
         error_message = f"Dag Run id {dag_run_id} not found in dag   {dag_id}"
@@ -445,7 +447,7 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: 
Session = NEW_SESSIO
             include_parentdag=True,
             only_failed=False,
         )
-        dag_run = session.query(DagRun).filter(DagRun.id == dag_run.id).one()
+        dag_run = session.execute(select(DagRun).where(DagRun.id == 
dag_run.id)).scalar_one()
         return dagrun_schema.dump(dag_run)
 
 
@@ -458,8 +460,8 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: 
Session = NEW_SESSIO
 @provide_session
 def set_dag_run_note(*, dag_id: str, dag_run_id: str, session: Session = 
NEW_SESSION) -> APIResponse:
     """Set the note for a dag run."""
-    dag_run: DagRun | None = (
-        session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == 
dag_run_id).one_or_none()
+    dag_run: DagRun | None = session.scalar(
+        select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == 
dag_run_id)
     )
     if dag_run is None:
         error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}"
diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py 
b/airflow/api_connexion/endpoints/dag_warning_endpoint.py
index 5a73afd1a3..66aa1184a1 100644
--- a/airflow/api_connexion/endpoints/dag_warning_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from sqlalchemy import func, select
 from sqlalchemy.orm import Session
 
 from airflow.api_connexion import security
@@ -48,14 +49,14 @@ def get_dag_warnings(
     :param warning_type: the warning type to optionally filter by
     """
     allowed_filter_attrs = ["dag_id", "warning_type", "message", "timestamp"]
-    query = session.query(DagWarningModel)
+    query = select(DagWarningModel)
     if dag_id:
-        query = query.filter(DagWarningModel.dag_id == dag_id)
+        query = query.where(DagWarningModel.dag_id == dag_id)
     if warning_type:
-        query = query.filter(DagWarningModel.warning_type == warning_type)
-    total_entries = query.count()
+        query = query.where(DagWarningModel.warning_type == warning_type)
+    total_entries = session.scalar(select(func.count()).select_from(query))
     query = apply_sorting(query=query, order_by=order_by, 
allowed_attrs=allowed_filter_attrs)
-    dag_warnings = query.offset(offset).limit(limit).all()
+    dag_warnings = session.scalars(query.offset(offset).limit(limit)).all()
     return dag_warning_collection_schema.dump(
         DagWarningCollection(dag_warnings=dag_warnings, 
total_entries=total_entries)
     )
diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py 
b/airflow/api_connexion/endpoints/dataset_endpoint.py
index 42e8bb3c36..9f4fa443b9 100644
--- a/airflow/api_connexion/endpoints/dataset_endpoint.py
+++ b/airflow/api_connexion/endpoints/dataset_endpoint.py
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from sqlalchemy import func
+from sqlalchemy import func, select
 from sqlalchemy.orm import Session, joinedload, subqueryload
 
 from airflow.api_connexion import security
@@ -39,11 +39,10 @@ from airflow.utils.session import NEW_SESSION, 
provide_session
 @provide_session
 def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse:
     """Get a Dataset."""
-    dataset = (
-        session.query(DatasetModel)
-        .filter(DatasetModel.uri == uri)
+    dataset = session.scalar(
+        select(DatasetModel)
+        .where(DatasetModel.uri == uri)
         .options(joinedload(DatasetModel.consuming_dags), 
joinedload(DatasetModel.producing_tasks))
-        .one_or_none()
     )
     if not dataset:
         raise NotFound(
@@ -67,17 +66,16 @@ def get_datasets(
     """Get datasets."""
     allowed_attrs = ["id", "uri", "created_at", "updated_at"]
 
-    total_entries = session.query(func.count(DatasetModel.id)).scalar()
-    query = session.query(DatasetModel)
+    total_entries = session.scalars(select(func.count(DatasetModel.id))).one()
+    query = select(DatasetModel)
     if uri_pattern:
-        query = query.filter(DatasetModel.uri.ilike(f"%{uri_pattern}%"))
+        query = query.where(DatasetModel.uri.ilike(f"%{uri_pattern}%"))
     query = apply_sorting(query, order_by, {}, allowed_attrs)
-    datasets = (
+    datasets = session.scalars(
         query.options(subqueryload(DatasetModel.consuming_dags), 
subqueryload(DatasetModel.producing_tasks))
         .offset(offset)
         .limit(limit)
-        .all()
-    )
+    ).all()
     return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, 
total_entries=total_entries))
 
 
@@ -99,24 +97,24 @@ def get_dataset_events(
     """Get dataset events."""
     allowed_attrs = ["source_dag_id", "source_task_id", "source_run_id", 
"source_map_index", "timestamp"]
 
-    query = session.query(DatasetEvent)
+    query = select(DatasetEvent)
 
     if dataset_id:
-        query = query.filter(DatasetEvent.dataset_id == dataset_id)
+        query = query.where(DatasetEvent.dataset_id == dataset_id)
     if source_dag_id:
-        query = query.filter(DatasetEvent.source_dag_id == source_dag_id)
+        query = query.where(DatasetEvent.source_dag_id == source_dag_id)
     if source_task_id:
-        query = query.filter(DatasetEvent.source_task_id == source_task_id)
+        query = query.where(DatasetEvent.source_task_id == source_task_id)
     if source_run_id:
-        query = query.filter(DatasetEvent.source_run_id == source_run_id)
+        query = query.where(DatasetEvent.source_run_id == source_run_id)
     if source_map_index:
-        query = query.filter(DatasetEvent.source_map_index == source_map_index)
+        query = query.where(DatasetEvent.source_map_index == source_map_index)
 
     query = query.options(subqueryload(DatasetEvent.created_dagruns))
 
-    total_entries = query.count()
+    total_entries = session.scalar(select(func.count()).select_from(query))
     query = apply_sorting(query, order_by, {}, allowed_attrs)
-    events = query.offset(offset).limit(limit).all()
+    events = session.scalars(query.offset(offset).limit(limit)).all()
     return dataset_event_collection_schema.dump(
         DatasetEventCollection(dataset_events=events, 
total_entries=total_entries)
     )
diff --git a/airflow/api_connexion/endpoints/event_log_endpoint.py 
b/airflow/api_connexion/endpoints/event_log_endpoint.py
index 335886189c..28615f6fec 100644
--- a/airflow/api_connexion/endpoints/event_log_endpoint.py
+++ b/airflow/api_connexion/endpoints/event_log_endpoint.py
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from sqlalchemy import func
+from sqlalchemy import func, select
 from sqlalchemy.orm import Session
 
 from airflow.api_connexion import security
@@ -65,10 +65,10 @@ def get_event_logs(
         "owner",
         "extra",
     ]
-    total_entries = session.query(func.count(Log.id)).scalar()
-    query = session.query(Log)
+    total_entries = session.scalars(func.count(Log.id)).one()
+    query = select(Log)
     query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
-    event_logs = query.offset(offset).limit(limit).all()
+    event_logs = session.scalars(query.offset(offset).limit(limit)).all()
     return event_log_collection_schema.dump(
         EventLogCollection(event_logs=event_logs, total_entries=total_entries)
     )
diff --git a/airflow/api_connexion/endpoints/extra_link_endpoint.py 
b/airflow/api_connexion/endpoints/extra_link_endpoint.py
index 2b12667e7c..e28822522e 100644
--- a/airflow/api_connexion/endpoints/extra_link_endpoint.py
+++ b/airflow/api_connexion/endpoints/extra_link_endpoint.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from sqlalchemy import select
 from sqlalchemy.orm.session import Session
 
 from airflow import DAG
@@ -57,14 +58,12 @@ def get_extra_links(
     except TaskNotFound:
         raise NotFound("Task not found", detail=f'Task with ID = "{task_id}" 
not found')
 
-    ti = (
-        session.query(TaskInstance)
-        .filter(
+    ti = session.scalar(
+        select(TaskInstance).where(
             TaskInstance.dag_id == dag_id,
             TaskInstance.run_id == dag_run_id,
             TaskInstance.task_id == task_id,
         )
-        .one_or_none()
     )
 
     if not ti:
diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py 
b/airflow/api_connexion/endpoints/import_error_endpoint.py
index 3ffd8f11a5..5e871e08bd 100644
--- a/airflow/api_connexion/endpoints/import_error_endpoint.py
+++ b/airflow/api_connexion/endpoints/import_error_endpoint.py
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from sqlalchemy import func
+from sqlalchemy import func, select
 from sqlalchemy.orm import Session
 
 from airflow.api_connexion import security
@@ -60,10 +60,10 @@ def get_import_errors(
     """Get all import errors."""
     to_replace = {"import_error_id": "id"}
     allowed_filter_attrs = ["import_error_id", "timestamp", "filename"]
-    total_entries = session.query(func.count(ImportErrorModel.id)).scalar()
-    query = session.query(ImportErrorModel)
+    total_entries = session.scalars(func.count(ImportErrorModel.id)).one()
+    query = select(ImportErrorModel)
     query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
-    import_errors = query.offset(offset).limit(limit).all()
+    import_errors = session.scalars(query.offset(offset).limit(limit)).all()
     return import_error_collection_schema.dump(
         ImportErrorCollection(import_errors=import_errors, 
total_entries=total_entries)
     )
diff --git a/airflow/api_connexion/endpoints/log_endpoint.py 
b/airflow/api_connexion/endpoints/log_endpoint.py
index a81258ce9d..8df712e67c 100644
--- a/airflow/api_connexion/endpoints/log_endpoint.py
+++ b/airflow/api_connexion/endpoints/log_endpoint.py
@@ -21,6 +21,7 @@ from typing import Any
 from flask import Response, request
 from itsdangerous.exc import BadSignature
 from itsdangerous.url_safe import URLSafeSerializer
+from sqlalchemy import select
 from sqlalchemy.orm import Session, joinedload
 
 from airflow.api_connexion import security
@@ -28,7 +29,7 @@ from airflow.api_connexion.exceptions import BadRequest, 
NotFound
 from airflow.api_connexion.schemas.log_schema import LogResponseObject, 
logs_schema
 from airflow.api_connexion.types import APIResponse
 from airflow.exceptions import TaskNotFound
-from airflow.models import TaskInstance
+from airflow.models import TaskInstance, Trigger
 from airflow.security import permissions
 from airflow.utils.airflow_flask_app import get_airflow_app
 from airflow.utils.log.log_reader import TaskLogReader
@@ -77,18 +78,17 @@ def get_log(
     if not task_log_reader.supports_read:
         raise BadRequest("Task log handler does not support read logs.")
     query = (
-        session.query(TaskInstance)
-        .filter(
+        select(TaskInstance)
+        .where(
             TaskInstance.task_id == task_id,
             TaskInstance.dag_id == dag_id,
             TaskInstance.run_id == dag_run_id,
             TaskInstance.map_index == map_index,
         )
         .join(TaskInstance.dag_run)
-        .options(joinedload("trigger"))
-        .options(joinedload("trigger.triggerer_job"))
+        
.options(joinedload(TaskInstance.trigger).joinedload(Trigger.triggerer_job))
     )
-    ti = query.one_or_none()
+    ti = session.scalar(query)
     if ti is None:
         metadata["end_of_log"] = True
         raise NotFound(title="TaskInstance not found")
diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py 
b/airflow/api_connexion/endpoints/pool_endpoint.py
index a760ee4d83..3668741c1e 100644
--- a/airflow/api_connexion/endpoints/pool_endpoint.py
+++ b/airflow/api_connexion/endpoints/pool_endpoint.py
@@ -20,7 +20,7 @@ from http import HTTPStatus
 
 from flask import Response
 from marshmallow import ValidationError
-from sqlalchemy import delete, func
+from sqlalchemy import delete, func, select
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import Session
 
@@ -52,7 +52,7 @@ def delete_pool(*, pool_name: str, session: Session = 
NEW_SESSION) -> APIRespons
 @provide_session
 def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse:
     """Get a pool."""
-    obj = session.query(Pool).filter(Pool.pool == pool_name).one_or_none()
+    obj = session.scalar(select(Pool).where(Pool.pool == pool_name))
     if obj is None:
         raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
     return pool_schema.dump(obj)
@@ -71,10 +71,10 @@ def get_pools(
     """Get all pools."""
     to_replace = {"name": "pool"}
     allowed_filter_attrs = ["name", "slots", "id"]
-    total_entries = session.query(func.count(Pool.id)).scalar()
-    query = session.query(Pool)
+    total_entries = session.scalars(func.count(Pool.id)).one()
+    query = select(Pool)
     query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
-    pools = query.offset(offset).limit(limit).all()
+    pools = session.scalars(query.offset(offset).limit(limit)).all()
     return pool_collection_schema.dump(PoolCollection(pools=pools, 
total_entries=total_entries))
 
 
@@ -98,7 +98,7 @@ def patch_pool(
     except KeyError:
         pass
 
-    pool = session.query(Pool).filter(Pool.pool == pool_name).first()
+    pool = session.scalar(select(Pool).where(Pool.pool == pool_name).limit(1))
     if not pool:
         raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
 
diff --git a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py 
b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py
index 4ed40caae5..609c45893f 100644
--- a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py
+++ b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py
@@ -21,7 +21,7 @@ from http import HTTPStatus
 from connexion import NoContent
 from flask import request
 from marshmallow import ValidationError
-from sqlalchemy import asc, desc, func
+from sqlalchemy import asc, desc, func, select
 
 from airflow.api_connexion import security
 from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, 
NotFound
@@ -69,7 +69,7 @@ def get_roles(*, order_by: str = "name", limit: int, offset: 
int | None = None)
     """Get roles."""
     appbuilder = get_airflow_app().appbuilder
     session = appbuilder.get_session
-    total_entries = session.query(func.count(Role.id)).scalar()
+    total_entries = session.scalars(select(func.count(Role.id))).one()
     direction = desc if order_by.startswith("-") else asc
     to_replace = {"role_id": "id"}
     order_param = order_by.strip("-")
@@ -81,8 +81,12 @@ def get_roles(*, order_by: str = "name", limit: int, offset: 
int | None = None)
             f"the attribute does not exist on the model"
         )
 
-    query = session.query(Role)
-    roles = query.order_by(direction(getattr(Role, 
order_param))).offset(offset).limit(limit).all()
+    query = select(Role)
+    roles = (
+        session.scalars(query.order_by(direction(getattr(Role, 
order_param))).offset(offset).limit(limit))
+        .unique()
+        .all()
+    )
 
     return role_collection_schema.dump(RoleCollection(roles=roles, 
total_entries=total_entries))
 
@@ -92,9 +96,9 @@ def get_roles(*, order_by: str = "name", limit: int, offset: 
int | None = None)
 def get_permissions(*, limit: int, offset: int | None = None) -> APIResponse:
     """Get permissions."""
     session = get_airflow_app().appbuilder.get_session
-    total_entries = session.query(func.count(Action.id)).scalar()
-    query = session.query(Action)
-    actions = query.offset(offset).limit(limit).all()
+    total_entries = session.scalars(select(func.count(Action.id))).one()
+    query = select(Action)
+    actions = session.scalars(query.offset(offset).limit(limit)).all()
     return action_collection_schema.dump(ActionCollection(actions=actions, 
total_entries=total_entries))
 
 
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py 
b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 533d97c858..3028b0bb73 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -19,11 +19,10 @@ from __future__ import annotations
 from typing import Any, Iterable, TypeVar
 
 from marshmallow import ValidationError
-from sqlalchemy import and_, func, or_
+from sqlalchemy import and_, func, or_, select
 from sqlalchemy.exc import MultipleResultsFound
 from sqlalchemy.orm import Session, joinedload
-from sqlalchemy.orm.query import Query
-from sqlalchemy.sql import ClauseElement
+from sqlalchemy.sql import ClauseElement, Select
 
 from airflow.api_connexion import security
 from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
@@ -72,8 +71,8 @@ def get_task_instance(
 ) -> APIResponse:
     """Get task instance."""
     query = (
-        session.query(TI)
-        .filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == 
task_id)
+        select(TI)
+        .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == 
task_id)
         .join(TI.dag_run)
         .outerjoin(
             SlaMiss,
@@ -83,12 +82,12 @@ def get_task_instance(
                 SlaMiss.task_id == TI.task_id,
             ),
         )
-        .add_entity(SlaMiss)
+        .add_columns(SlaMiss)
         .options(joinedload(TI.rendered_task_instance_fields))
     )
 
     try:
-        task_instance = query.one_or_none()
+        task_instance = session.execute(query).one_or_none()
     except MultipleResultsFound:
         raise NotFound(
             "Task instance not found", detail="Task instance is mapped, add 
the map_index value to the URL"
@@ -121,10 +120,8 @@ def get_mapped_task_instance(
 ) -> APIResponse:
     """Get task instance."""
     query = (
-        session.query(TI)
-        .filter(
-            TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == 
task_id, TI.map_index == map_index
-        )
+        select(TI)
+        .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == 
task_id, TI.map_index == map_index)
         .join(TI.dag_run)
         .outerjoin(
             SlaMiss,
@@ -134,10 +131,11 @@ def get_mapped_task_instance(
                 SlaMiss.task_id == TI.task_id,
             ),
         )
-        .add_entity(SlaMiss)
+        .add_columns(SlaMiss)
         .options(joinedload(TI.rendered_task_instance_fields))
     )
-    task_instance = query.one_or_none()
+    task_instance = session.execute(query).one_or_none()
+
     if task_instance is None:
         raise NotFound("Task instance not found")
 
@@ -192,13 +190,14 @@ def get_mapped_task_instances(
     states = _convert_state(state)
 
     base_query = (
-        session.query(TI)
-        .filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == 
task_id, TI.map_index >= 0)
+        select(TI)
+        .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == 
task_id, TI.map_index >= 0)
         .join(TI.dag_run)
     )
 
     # 0 can mean a mapped TI that expanded to an empty list, so it is not an 
automatic 404
-    if base_query.with_entities(func.count("*")).scalar() == 0:
+    unfiltered_total_count = 
session.execute(select(func.count("*")).select_from(base_query)).scalar()
+    if unfiltered_total_count == 0:
         dag = get_airflow_app().dag_bag.get_dag(dag_id)
         if not dag:
             error_message = f"DAG {dag_id} not found"
@@ -212,50 +211,54 @@ def get_mapped_task_instances(
             raise NotFound(error_message)
 
     # Other search criteria
-    query = _apply_range_filter(
+    base_query = _apply_range_filter(
         base_query,
         key=DR.execution_date,
         value_range=(execution_date_gte, execution_date_lte),
     )
-    query = _apply_range_filter(query, key=TI.start_date, 
value_range=(start_date_gte, start_date_lte))
-    query = _apply_range_filter(query, key=TI.end_date, 
value_range=(end_date_gte, end_date_lte))
-    query = _apply_range_filter(query, key=TI.duration, 
value_range=(duration_gte, duration_lte))
-    query = _apply_range_filter(query, key=TI.updated_at, 
value_range=(updated_at_gte, updated_at_lte))
-    query = _apply_array_filter(query, key=TI.state, values=states)
-    query = _apply_array_filter(query, key=TI.pool, values=pool)
-    query = _apply_array_filter(query, key=TI.queue, values=queue)
+    base_query = _apply_range_filter(
+        base_query, key=TI.start_date, value_range=(start_date_gte, 
start_date_lte)
+    )
+    base_query = _apply_range_filter(base_query, key=TI.end_date, 
value_range=(end_date_gte, end_date_lte))
+    base_query = _apply_range_filter(base_query, key=TI.duration, 
value_range=(duration_gte, duration_lte))
+    base_query = _apply_range_filter(
+        base_query, key=TI.updated_at, value_range=(updated_at_gte, 
updated_at_lte)
+    )
+    base_query = _apply_array_filter(base_query, key=TI.state, values=states)
+    base_query = _apply_array_filter(base_query, key=TI.pool, values=pool)
+    base_query = _apply_array_filter(base_query, key=TI.queue, values=queue)
 
     # Count elements before joining extra columns
-    total_entries = query.with_entities(func.count("*")).scalar()
+    total_entries = 
session.execute(select(func.count("*")).select_from(base_query)).scalar()
 
     # Add SLA miss
-    query = (
-        query.join(
+    entry_query = (
+        base_query.outerjoin(
             SlaMiss,
             and_(
                 SlaMiss.dag_id == TI.dag_id,
                 SlaMiss.task_id == TI.task_id,
                 SlaMiss.execution_date == DR.execution_date,
             ),
-            isouter=True,
         )
-        .add_entity(SlaMiss)
+        .add_columns(SlaMiss)
         .options(joinedload(TI.rendered_task_instance_fields))
     )
 
     if order_by:
         if order_by == "state":
-            query = query.order_by(TI.state.asc(), TI.map_index.asc())
+            entry_query = entry_query.order_by(TI.state.asc(), 
TI.map_index.asc())
         elif order_by == "-state":
-            query = query.order_by(TI.state.desc(), TI.map_index.asc())
+            entry_query = entry_query.order_by(TI.state.desc(), 
TI.map_index.asc())
         elif order_by == "-map_index":
-            query = query.order_by(TI.map_index.desc())
+            entry_query = entry_query.order_by(TI.map_index.desc())
         else:
             raise BadRequest(detail=f"Ordering with '{order_by}' is not 
supported")
     else:
-        query = query.order_by(TI.map_index.asc())
+        entry_query = entry_query.order_by(TI.map_index.asc())
 
-    task_instances = query.offset(offset).limit(limit).all()
+    # using execute because we want the SlaMiss entity. Scalars don't return 
None for missing entities
+    task_instances = 
session.execute(entry_query.offset(offset).limit(limit)).all()
     return task_instance_collection_schema.dump(
         TaskInstanceCollection(task_instances=task_instances, 
total_entries=total_entries)
     )
@@ -267,19 +270,19 @@ def _convert_state(states: Iterable[str] | None) -> 
list[str | None] | None:
     return [State.NONE if s == "none" else s for s in states]
 
 
-def _apply_array_filter(query: Query, key: ClauseElement, values: 
Iterable[Any] | None) -> Query:
+def _apply_array_filter(query: Select, key: ClauseElement, values: 
Iterable[Any] | None) -> Select:
     if values is not None:
         cond = ((key == v) for v in values)
-        query = query.filter(or_(*cond))
+        query = query.where(or_(*cond))
     return query
 
 
-def _apply_range_filter(query: Query, key: ClauseElement, value_range: 
tuple[T, T]) -> Query:
+def _apply_range_filter(query: Select, key: ClauseElement, value_range: 
tuple[T, T]) -> Select:
     gte_value, lte_value = value_range
     if gte_value is not None:
-        query = query.filter(key >= gte_value)
+        query = query.where(key >= gte_value)
     if lte_value is not None:
-        query = query.filter(key <= lte_value)
+        query = query.where(key <= lte_value)
     return query
 
 
@@ -328,12 +331,12 @@ def get_task_instances(
     # Because state can be 'none'
     states = _convert_state(state)
 
-    base_query = session.query(TI).join(TI.dag_run)
+    base_query = select(TI).join(TI.dag_run)
 
     if dag_id != "~":
-        base_query = base_query.filter(TI.dag_id == dag_id)
+        base_query = base_query.where(TI.dag_id == dag_id)
     if dag_run_id != "~":
-        base_query = base_query.filter(TI.run_id == dag_run_id)
+        base_query = base_query.where(TI.run_id == dag_run_id)
     base_query = _apply_range_filter(
         base_query,
         key=DR.execution_date,
@@ -352,22 +355,26 @@ def get_task_instances(
     base_query = _apply_array_filter(base_query, key=TI.queue, values=queue)
 
     # Count elements before joining extra columns
-    total_entries = base_query.with_entities(func.count("*")).scalar()
+    count_query = select(func.count("*")).select_from(base_query)
+    total_entries = session.execute(count_query).scalar()
+
     # Add join
-    query = (
-        base_query.join(
+    entry_query = (
+        base_query.outerjoin(
             SlaMiss,
             and_(
                 SlaMiss.dag_id == TI.dag_id,
                 SlaMiss.task_id == TI.task_id,
                 SlaMiss.execution_date == DR.execution_date,
             ),
-            isouter=True,
         )
-        .add_entity(SlaMiss)
+        .add_columns(SlaMiss)
         .options(joinedload(TI.rendered_task_instance_fields))
+        .offset(offset)
+        .limit(limit)
     )
-    task_instances = query.offset(offset).limit(limit).all()
+    # using execute because we want the SlaMiss entity. Scalars don't return 
None for missing entities
+    task_instances = session.execute(entry_query).all()
     return task_instance_collection_schema.dump(
         TaskInstanceCollection(task_instances=task_instances, 
total_entries=total_entries)
     )
@@ -389,7 +396,7 @@ def get_task_instances_batch(session: Session = 
NEW_SESSION) -> APIResponse:
     except ValidationError as err:
         raise BadRequest(detail=str(err.messages))
     states = _convert_state(data["state"])
-    base_query = session.query(TI).join(TI.dag_run)
+    base_query = select(TI).join(TI.dag_run)
 
     base_query = _apply_array_filter(base_query, key=TI.dag_id, 
values=data["dag_ids"])
     base_query = _apply_range_filter(
@@ -413,7 +420,7 @@ def get_task_instances_batch(session: Session = 
NEW_SESSION) -> APIResponse:
     base_query = _apply_array_filter(base_query, key=TI.queue, 
values=data["queue"])
 
     # Count elements before joining extra columns
-    total_entries = base_query.with_entities(func.count("*")).scalar()
+    total_entries = 
session.execute(select(func.count("*")).select_from(base_query)).scalar()
     # Add join
     base_query = base_query.join(
         SlaMiss,
@@ -423,9 +430,10 @@ def get_task_instances_batch(session: Session = 
NEW_SESSION) -> APIResponse:
             SlaMiss.execution_date == DR.execution_date,
         ),
         isouter=True,
-    ).add_entity(SlaMiss)
+    ).add_columns(SlaMiss)
     ti_query = base_query.options(joinedload(TI.rendered_task_instance_fields))
-    task_instances = ti_query.all()
+    # using execute because we want the SlaMiss entity. Scalars don't return 
None for missing entities
+    task_instances = session.execute(ti_query).all()
 
     return task_instance_collection_schema.dump(
         TaskInstanceCollection(task_instances=task_instances, 
total_entries=total_entries)
@@ -461,9 +469,7 @@ def post_clear_task_instances(*, dag_id: str, session: 
Session = NEW_SESSION) ->
     downstream = data.pop("include_downstream", False)
     upstream = data.pop("include_upstream", False)
     if dag_run_id is not None:
-        dag_run: DR | None = (
-            session.query(DR).filter(DR.dag_id == dag_id, DR.run_id == 
dag_run_id).one_or_none()
-        )
+        dag_run: DR | None = session.scalar(select(DR).where(DR.dag_id == 
dag_id, DR.run_id == dag_run_id))
         if dag_run is None:
             error_message = f"Dag Run id {dag_run_id} not found in dag 
{dag_id}"
             raise NotFound(error_message)
@@ -486,16 +492,17 @@ def post_clear_task_instances(*, dag_id: str, session: 
Session = NEW_SESSION) ->
             # If we had upstream/downstream etc then also include those!
             task_ids.extend(tid for tid in dag.task_dict if tid != task_id)
     task_instances = dag.clear(dry_run=True, 
dag_bag=get_airflow_app().dag_bag, task_ids=task_ids, **data)
+
     if not dry_run:
         clear_task_instances(
-            task_instances.all(),
+            task_instances,
             session,
             dag=dag,
             dag_run_state=DagRunState.QUEUED if reset_dag_runs else False,
         )
 
     return task_instance_reference_collection_schema.dump(
-        TaskInstanceReferenceCollection(task_instances=task_instances.all())
+        TaskInstanceReferenceCollection(task_instances=task_instances)
     )
 
 
@@ -532,9 +539,11 @@ def post_set_task_instances_state(*, dag_id: str, session: 
Session = NEW_SESSION
     if (
         execution_date
         and (
-            session.query(TI)
-            .filter(TI.task_id == task_id, TI.dag_id == dag_id, 
TI.execution_date == execution_date)
-            .one_or_none()
+            session.scalars(
+                select(TI).where(
+                    TI.task_id == task_id, TI.dag_id == dag_id, 
TI.execution_date == execution_date
+                )
+            ).one_or_none()
         )
         is None
     ):
@@ -653,8 +662,8 @@ def set_task_instance_note(
         raise BadRequest(detail=str(err))
 
     query = (
-        session.query(TI)
-        .filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == 
task_id)
+        select(TI)
+        .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == 
task_id)
         .join(TI.dag_run)
         .outerjoin(
             SlaMiss,
@@ -664,16 +673,16 @@ def set_task_instance_note(
                 SlaMiss.task_id == TI.task_id,
             ),
         )
-        .add_entity(SlaMiss)
+        .add_columns(SlaMiss)
         .options(joinedload(TI.rendered_task_instance_fields))
     )
     if map_index == -1:
-        query = query.filter(or_(TI.map_index == -1, TI.map_index is None))
+        query = query.where(or_(TI.map_index == -1, TI.map_index is None))
     else:
-        query = query.filter(TI.map_index == map_index)
+        query = query.where(TI.map_index == map_index)
 
     try:
-        result = query.one_or_none()
+        result = session.execute(query).one_or_none()
     except MultipleResultsFound:
         raise NotFound(
             "Task instance not found", detail="Task instance is mapped, add 
the map_index value to the URL"
diff --git a/airflow/api_connexion/endpoints/user_endpoint.py 
b/airflow/api_connexion/endpoints/user_endpoint.py
index a9482701a4..2a88fb1b24 100644
--- a/airflow/api_connexion/endpoints/user_endpoint.py
+++ b/airflow/api_connexion/endpoints/user_endpoint.py
@@ -21,7 +21,7 @@ from http import HTTPStatus
 from connexion import NoContent
 from flask import request
 from marshmallow import ValidationError
-from sqlalchemy import asc, desc, func
+from sqlalchemy import asc, desc, func, select
 from werkzeug.security import generate_password_hash
 
 from airflow.api_connexion import security
@@ -55,7 +55,7 @@ def get_users(*, limit: int, order_by: str = "id", offset: 
str | None = None) ->
     """Get users."""
     appbuilder = get_airflow_app().appbuilder
     session = appbuilder.get_session
-    total_entries = session.query(func.count(User.id)).scalar()
+    total_entries = session.execute(select(func.count(User.id))).scalar()
     direction = desc if order_by.startswith("-") else asc
     to_replace = {"user_id": "id"}
     order_param = order_by.strip("-")
@@ -75,8 +75,8 @@ def get_users(*, limit: int, order_by: str = "id", offset: 
str | None = None) ->
             f"the attribute does not exist on the model"
         )
 
-    query = session.query(User)
-    users = query.order_by(direction(getattr(User, 
order_param))).offset(offset).limit(limit).all()
+    query = select(User).order_by(direction(getattr(User, 
order_param))).offset(offset).limit(limit)
+    users = session.scalars(query).all()
 
     return user_collection_schema.dump(UserCollection(users=users, 
total_entries=total_entries))
 
diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py 
b/airflow/api_connexion/endpoints/variable_endpoint.py
index da8f35fcb8..61a1871104 100644
--- a/airflow/api_connexion/endpoints/variable_endpoint.py
+++ b/airflow/api_connexion/endpoints/variable_endpoint.py
@@ -20,7 +20,7 @@ from http import HTTPStatus
 
 from flask import Response
 from marshmallow import ValidationError
-from sqlalchemy import func
+from sqlalchemy import func, select
 from sqlalchemy.orm import Session
 
 from airflow.api_connexion import security
@@ -57,10 +57,10 @@ def delete_variable(*, variable_key: str) -> Response:
 @provide_session
 def get_variable(*, variable_key: str, session: Session = NEW_SESSION) -> 
Response:
     """Get a variable by key."""
-    var = session.query(Variable).filter(Variable.key == variable_key)
-    if not var.count():
+    var = session.scalar(select(Variable).where(Variable.key == 
variable_key).limit(1))
+    if not var:
         raise NotFound("Variable not found")
-    return variable_schema.dump(var.first())
+    return variable_schema.dump(var)
 
 
 @security.requires_access([(permissions.ACTION_CAN_READ, 
permissions.RESOURCE_VARIABLE)])
@@ -74,12 +74,12 @@ def get_variables(
     session: Session = NEW_SESSION,
 ) -> Response:
     """Get all variable values."""
-    total_entries = session.query(func.count(Variable.id)).scalar()
+    total_entries = session.execute(select(func.count(Variable.id))).scalar()
     to_replace = {"value": "val"}
     allowed_filter_attrs = ["value", "key", "id"]
-    query = session.query(Variable)
+    query = select(Variable)
     query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
-    variables = query.offset(offset).limit(limit).all()
+    variables = session.scalars(query.offset(offset).limit(limit)).all()
     return variable_collection_schema.dump(
         {
             "variables": variables,
@@ -111,7 +111,7 @@ def patch_variable(
     if data["key"] != variable_key:
         raise BadRequest("Invalid post body", detail="key from request body 
doesn't match uri parameter")
     non_update_fields = ["key"]
-    variable = session.query(Variable).filter_by(key=variable_key).first()
+    variable = 
session.scalar(select(Variable).filter_by(key=variable_key).limit(1))
     if update_mask:
         data = extract_update_mask_data(update_mask, non_update_fields, data)
     for key, val in data.items():
diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py 
b/airflow/api_connexion/endpoints/xcom_endpoint.py
index 2ab5ec26f5..830cedb51c 100644
--- a/airflow/api_connexion/endpoints/xcom_endpoint.py
+++ b/airflow/api_connexion/endpoints/xcom_endpoint.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 import copy
 
 from flask import g
-from sqlalchemy import and_
+from sqlalchemy import and_, func, select
 from sqlalchemy.orm import Session
 
 from airflow.api_connexion import security
@@ -53,23 +53,23 @@ def get_xcom_entries(
     session: Session = NEW_SESSION,
 ) -> APIResponse:
     """Get all XCom values."""
-    query = session.query(XCom)
+    query = select(XCom)
     if dag_id == "~":
         appbuilder = get_airflow_app().appbuilder
         readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
-        query = query.filter(XCom.dag_id.in_(readable_dag_ids))
+        query = query.where(XCom.dag_id.in_(readable_dag_ids))
         query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == 
DR.run_id))
     else:
-        query = query.filter(XCom.dag_id == dag_id)
+        query = query.where(XCom.dag_id == dag_id)
         query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == 
DR.run_id))
 
     if task_id != "~":
-        query = query.filter(XCom.task_id == task_id)
+        query = query.where(XCom.task_id == task_id)
     if dag_run_id != "~":
-        query = query.filter(DR.run_id == dag_run_id)
+        query = query.where(DR.run_id == dag_run_id)
     query = query.order_by(DR.execution_date, XCom.task_id, XCom.dag_id, 
XCom.key)
-    total_entries = query.count()
-    query = query.offset(offset).limit(limit)
+    total_entries = 
session.execute(select(func.count()).select_from(query)).scalar()
+    query = session.scalars(query.offset(offset).limit(limit))
     return xcom_collection_schema.dump(XComCollection(xcom_entries=query, 
total_entries=total_entries))
 
 
@@ -93,15 +93,19 @@ def get_xcom_entry(
 ) -> APIResponse:
     """Get an XCom entry."""
     if deserialize:
-        query = session.query(XCom, XCom.value)
+        query = select(XCom, XCom.value)
     else:
-        query = session.query(XCom)
+        query = select(XCom)
 
-    query = query.filter(XCom.dag_id == dag_id, XCom.task_id == task_id, 
XCom.key == xcom_key)
+    query = query.where(XCom.dag_id == dag_id, XCom.task_id == task_id, 
XCom.key == xcom_key)
     query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == 
DR.run_id))
-    query = query.filter(DR.run_id == dag_run_id)
+    query = query.where(DR.run_id == dag_run_id)
+
+    if deserialize:
+        item = session.execute(query).one_or_none()
+    else:
+        item = session.scalars(query).one_or_none()
 
-    item = query.one_or_none()
     if item is None:
         raise NotFound("XCom entry not found")
 
diff --git a/airflow/api_connexion/parameters.py 
b/airflow/api_connexion/parameters.py
index 5f8c8ad360..f4f55cfecd 100644
--- a/airflow/api_connexion/parameters.py
+++ b/airflow/api_connexion/parameters.py
@@ -23,7 +23,7 @@ from typing import Any, Callable, Container, TypeVar, cast
 
 from pendulum.parsing import ParserError
 from sqlalchemy import text
-from sqlalchemy.orm.query import Query
+from sqlalchemy.sql import Select
 
 from airflow.api_connexion.exceptions import BadRequest
 from airflow.configuration import conf
@@ -106,11 +106,11 @@ def format_parameters(params_formatters: dict[str, 
Callable[[Any], Any]]) -> Cal
 
 
 def apply_sorting(
-    query: Query,
+    query: Select,
     order_by: str,
     to_replace: dict[str, str] | None = None,
     allowed_attrs: Container[str] | None = None,
-) -> Query:
+) -> Select:
     """Apply sorting to query."""
     lstriped_orderby = order_by.lstrip("-")
     if allowed_attrs and lstriped_orderby not in allowed_attrs:
diff --git a/airflow/models/abstractoperator.py 
b/airflow/models/abstractoperator.py
index ff4f5c4140..73d9204aab 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -22,6 +22,8 @@ import inspect
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, 
Iterable, Iterator, Sequence
 
+from sqlalchemy import select
+
 from airflow.compat.functools import cache
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -549,17 +551,15 @@ class AbstractOperator(Templater, DAGNode):
             total_length = None
 
         state: TaskInstanceState | None = None
-        unmapped_ti: TaskInstance | None = (
-            session.query(TaskInstance)
-            .filter(
+        unmapped_ti: TaskInstance | None = session.scalars(
+            select(TaskInstance).where(
                 TaskInstance.dag_id == self.dag_id,
                 TaskInstance.task_id == self.task_id,
                 TaskInstance.run_id == run_id,
                 TaskInstance.map_index == -1,
                 or_(TaskInstance.state.in_(State.unfinished), 
TaskInstance.state.is_(None)),
             )
-            .one_or_none()
-        )
+        ).one_or_none()
 
         all_expanded_tis: list[TaskInstance] = []
 
@@ -582,14 +582,14 @@ class AbstractOperator(Templater, DAGNode):
                 unmapped_ti.state = TaskInstanceState.SKIPPED
             else:
                 zero_index_ti_exists = (
-                    session.query(TaskInstance)
-                    .filter(
-                        TaskInstance.dag_id == self.dag_id,
-                        TaskInstance.task_id == self.task_id,
-                        TaskInstance.run_id == run_id,
-                        TaskInstance.map_index == 0,
+                    session.scalar(
+                        select(func.count(TaskInstance.task_id)).where(
+                            TaskInstance.dag_id == self.dag_id,
+                            TaskInstance.task_id == self.task_id,
+                            TaskInstance.run_id == run_id,
+                            TaskInstance.map_index == 0,
+                        )
                     )
-                    .count()
                     > 0
                 )
                 if not zero_index_ti_exists:
@@ -609,14 +609,12 @@ class AbstractOperator(Templater, DAGNode):
             indexes_to_map: Iterable[int] = ()
         else:
             # Only create "missing" ones.
-            current_max_mapping = (
-                session.query(func.max(TaskInstance.map_index))
-                .filter(
+            current_max_mapping = session.scalar(
+                select(func.max(TaskInstance.map_index)).where(
                     TaskInstance.dag_id == self.dag_id,
                     TaskInstance.task_id == self.task_id,
                     TaskInstance.run_id == run_id,
                 )
-                .scalar()
             )
             indexes_to_map = range(current_max_mapping + 1, total_length)
 
@@ -635,13 +633,14 @@ class AbstractOperator(Templater, DAGNode):
 
         # Any (old) task instances with inapplicable indexes (>= the total
         # number we need) are set to "REMOVED".
-        query = session.query(TaskInstance).filter(
+        query = select(TaskInstance).where(
             TaskInstance.dag_id == self.dag_id,
             TaskInstance.task_id == self.task_id,
             TaskInstance.run_id == run_id,
             TaskInstance.map_index >= total_expanded_ti_count,
         )
-        to_update = with_row_locks(query, of=TaskInstance, session=session, 
**skip_locked(session=session))
+        query = with_row_locks(query, of=TaskInstance, session=session, 
**skip_locked(session=session))
+        to_update = session.scalars(query)
         for ti in to_update:
             ti.state = TaskInstanceState.REMOVED
         session.flush()
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 3d0812b62a..12f4206fcf 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -49,6 +49,7 @@ from typing import (
 import attr
 import pendulum
 from dateutil.relativedelta import relativedelta
+from sqlalchemy import select
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.exc import NoResultFound
 
@@ -1256,12 +1257,12 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         Clears the state of task instances associated with the task, following
         the parameters specified.
         """
-        qry = session.query(TaskInstance).filter(TaskInstance.dag_id == 
self.dag_id)
+        qry = select(TaskInstance).where(TaskInstance.dag_id == self.dag_id)
 
         if start_date:
-            qry = qry.filter(TaskInstance.execution_date >= start_date)
+            qry = qry.where(TaskInstance.execution_date >= start_date)
         if end_date:
-            qry = qry.filter(TaskInstance.execution_date <= end_date)
+            qry = qry.where(TaskInstance.execution_date <= end_date)
 
         tasks = [self.task_id]
 
@@ -1271,8 +1272,8 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         if downstream:
             tasks += [t.task_id for t in 
self.get_flat_relatives(upstream=False)]
 
-        qry = qry.filter(TaskInstance.task_id.in_(tasks))
-        results = qry.all()
+        qry = qry.where(TaskInstance.task_id.in_(tasks))
+        results = session.scalars(qry).all()
         count = len(results)
         clear_task_instances(results, session, dag=self.dag)
         session.commit()
@@ -1289,16 +1290,15 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         from airflow.models import DagRun
 
         end_date = end_date or timezone.utcnow()
-        return (
-            session.query(TaskInstance)
+        return session.scalars(
+            select(TaskInstance)
             .join(TaskInstance.dag_run)
-            .filter(TaskInstance.dag_id == self.dag_id)
-            .filter(TaskInstance.task_id == self.task_id)
-            .filter(DagRun.execution_date >= start_date)
-            .filter(DagRun.execution_date <= end_date)
+            .where(TaskInstance.dag_id == self.dag_id)
+            .where(TaskInstance.task_id == self.task_id)
+            .where(DagRun.execution_date >= start_date)
+            .where(DagRun.execution_date <= end_date)
             .order_by(DagRun.execution_date)
-            .all()
-        )
+        ).all()
 
     @provide_session
     def run(
@@ -1327,14 +1327,12 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         for info in self.dag.iter_dagrun_infos_between(start_date, end_date, 
align=False):
             ignore_depends_on_past = info.logical_date == start_date and 
ignore_first_depends_on_past
             try:
-                dag_run = (
-                    session.query(DagRun)
-                    .filter(
+                dag_run = session.scalars(
+                    select(DagRun).where(
                         DagRun.dag_id == self.dag_id,
                         DagRun.execution_date == info.logical_date,
                     )
-                    .one()
-                )
+                ).one()
                 ti = TaskInstance(self, run_id=dag_run.run_id)
             except NoResultFound:
                 # This is _mostly_ only used in tests
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 7e2f23f47b..892ff5cef4 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -55,12 +55,27 @@ import pendulum
 import re2 as re
 from dateutil.relativedelta import relativedelta
 from pendulum.tz.timezone import Timezone
-from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, 
Text, and_, case, func, not_, or_
+from sqlalchemy import (
+    Boolean,
+    Column,
+    ForeignKey,
+    Index,
+    Integer,
+    String,
+    Text,
+    and_,
+    case,
+    func,
+    not_,
+    or_,
+    select,
+    update,
+)
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import backref, joinedload, relationship
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
-from sqlalchemy.sql import expression
+from sqlalchemy.sql import Select, expression
 
 import airflow.templates
 from airflow import settings, utils
@@ -202,11 +217,11 @@ def get_last_dagrun(dag_id, session, 
include_externally_triggered=False):
     Overridden DagRuns are ignored.
     """
     DR = DagRun
-    query = session.query(DR).filter(DR.dag_id == dag_id)
+    query = select(DR).where(DR.dag_id == dag_id)
     if not include_externally_triggered:
-        query = query.filter(DR.external_trigger == expression.false())
+        query = query.where(DR.external_trigger == expression.false())
     query = query.order_by(DR.execution_date.desc())
-    return query.first()
+    return session.scalar(query.limit(1))
 
 
 def get_dataset_triggered_next_run_info(
@@ -224,31 +239,27 @@ def get_dataset_triggered_next_run_info(
             "ready": x.ready,
             "total": x.total,
         }
-        for x in session.query(
-            DagScheduleDatasetReference.dag_id,
-            # This is a dirty hack to workaround group by requiring an 
aggregate, since grouping by dataset
-            # is not what we want to do here...but it works
-            case((func.count() == 1, func.max(DatasetModel.uri)), 
else_="").label("uri"),
-            func.count().label("total"),
-            func.sum(case((DDRQ.target_dag_id.is_not(None), 1), 
else_=0)).label("ready"),
-        )
-        .join(
-            DDRQ,
-            and_(
-                DDRQ.dataset_id == DagScheduleDatasetReference.dataset_id,
-                DDRQ.target_dag_id == DagScheduleDatasetReference.dag_id,
-            ),
-            isouter=True,
-        )
-        .join(
-            DatasetModel,
-            DatasetModel.id == DagScheduleDatasetReference.dataset_id,
-        )
-        .group_by(
-            DagScheduleDatasetReference.dag_id,
-        )
-        .filter(DagScheduleDatasetReference.dag_id.in_(dag_ids))
-        .all()
+        for x in session.execute(
+            select(
+                DagScheduleDatasetReference.dag_id,
+                # This is a dirty hack to workaround group by requiring an 
aggregate,
+                # since grouping by dataset is not what we want to do 
here...but it works
+                case((func.count() == 1, func.max(DatasetModel.uri)), 
else_="").label("uri"),
+                func.count().label("total"),
+                func.sum(case((DDRQ.target_dag_id.is_not(None), 1), 
else_=0)).label("ready"),
+            )
+            .join(
+                DDRQ,
+                and_(
+                    DDRQ.dataset_id == DagScheduleDatasetReference.dataset_id,
+                    DDRQ.target_dag_id == DagScheduleDatasetReference.dag_id,
+                ),
+                isouter=True,
+            )
+            .join(DatasetModel, DatasetModel.id == 
DagScheduleDatasetReference.dataset_id)
+            .group_by(DagScheduleDatasetReference.dag_id)
+            .where(DagScheduleDatasetReference.dag_id.in_(dag_ids))
+        ).all()
     }
 
 
@@ -1296,11 +1307,13 @@ class DAG(LoggingMixin):
         has been reached.
         """
         TI = TaskInstance
-        qry = session.query(func.count(TI.task_id)).filter(
-            TI.dag_id == self.dag_id,
-            TI.state == State.RUNNING,
+        total_tasks = session.scalar(
+            select(func.count(TI.task_id)).where(
+                TI.dag_id == self.dag_id,
+                TI.state == TaskInstanceState.RUNNING,
+            )
         )
-        return qry.scalar() >= self.max_active_tasks
+        return total_tasks >= self.max_active_tasks
 
     @property
     def concurrency_reached(self):
@@ -1315,12 +1328,12 @@ class DAG(LoggingMixin):
     @provide_session
     def get_is_active(self, session=NEW_SESSION) -> None:
         """Returns a boolean indicating whether this DAG is active."""
-        return session.query(DagModel.is_active).filter(DagModel.dag_id == 
self.dag_id).scalar()
+        return session.scalar(select(DagModel.is_active).where(DagModel.dag_id 
== self.dag_id))
 
     @provide_session
     def get_is_paused(self, session=NEW_SESSION) -> None:
         """Returns a boolean indicating whether this DAG is paused."""
-        return session.query(DagModel.is_paused).filter(DagModel.dag_id == 
self.dag_id).scalar()
+        return session.scalar(select(DagModel.is_paused).where(DagModel.dag_id 
== self.dag_id))
 
     @property
     def is_paused(self):
@@ -1402,19 +1415,18 @@ class DAG(LoggingMixin):
         :param session:
         :return: number greater than 0 for active dag runs
         """
-        # .count() is inefficient
-        query = session.query(func.count()).filter(DagRun.dag_id == 
self.dag_id)
+        query = select(func.count()).where(DagRun.dag_id == self.dag_id)
         if only_running:
-            query = query.filter(DagRun.state == State.RUNNING)
+            query = query.where(DagRun.state == DagRunState.RUNNING)
         else:
-            query = query.filter(DagRun.state.in_({State.RUNNING, 
State.QUEUED}))
+            query = query.where(DagRun.state.in_({DagRunState.RUNNING, 
DagRunState.QUEUED}))
 
         if external_trigger is not None:
-            query = query.filter(
+            query = query.where(
                 DagRun.external_trigger == (expression.true() if 
external_trigger else expression.false())
             )
 
-        return query.scalar()
+        return session.scalar(query)
 
     @provide_session
     def get_dagrun(
@@ -1434,12 +1446,12 @@ class DAG(LoggingMixin):
         """
         if not (execution_date or run_id):
             raise TypeError("You must provide either the execution_date or the 
run_id")
-        query = session.query(DagRun)
+        query = select(DagRun)
         if execution_date:
-            query = query.filter(DagRun.dag_id == self.dag_id, 
DagRun.execution_date == execution_date)
+            query = query.where(DagRun.dag_id == self.dag_id, 
DagRun.execution_date == execution_date)
         if run_id:
-            query = query.filter(DagRun.dag_id == self.dag_id, DagRun.run_id 
== run_id)
-        return query.first()
+            query = query.where(DagRun.dag_id == self.dag_id, DagRun.run_id == 
run_id)
+        return session.scalar(query)
 
     @provide_session
     def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION):
@@ -1451,22 +1463,20 @@ class DAG(LoggingMixin):
         :param session:
         :return: The list of DagRuns found.
         """
-        dagruns = (
-            session.query(DagRun)
-            .filter(
+        dagruns = session.scalars(
+            select(DagRun).where(
                 DagRun.dag_id == self.dag_id,
                 DagRun.execution_date >= start_date,
                 DagRun.execution_date <= end_date,
             )
-            .all()
-        )
+        ).all()
 
         return dagruns
 
     @provide_session
     def get_latest_execution_date(self, session: Session = NEW_SESSION) -> 
pendulum.DateTime | None:
         """Returns the latest date for which at least one dag run exists."""
-        return 
session.query(func.max(DagRun.execution_date)).filter(DagRun.dag_id == 
self.dag_id).scalar()
+        return 
session.scalar(select(func.max(DagRun.execution_date)).where(DagRun.dag_id == 
self.dag_id))
 
     @property
     def latest_execution_date(self):
@@ -1553,16 +1563,15 @@ class DAG(LoggingMixin):
         corresponding to any DagRunType. It can have less if there are
         less than ``num`` scheduled DAG runs before ``base_date``.
         """
-        execution_dates: list[Any] = (
-            session.query(DagRun.execution_date)
-            .filter(
+        execution_dates: list[Any] = session.execute(
+            select(DagRun.execution_date)
+            .where(
                 DagRun.dag_id == self.dag_id,
                 DagRun.execution_date <= base_date,
             )
             .order_by(DagRun.execution_date.desc())
             .limit(num)
-            .all()
-        )
+        ).all()
 
         if len(execution_dates) == 0:
             return self.get_task_instances(start_date=base_date, 
end_date=base_date, session=session)
@@ -1598,7 +1607,7 @@ class DAG(LoggingMixin):
             exclude_task_ids=(),
             session=session,
         )
-        return cast(Query, query).order_by(DagRun.execution_date).all()
+        return session.scalars(cast(Select, 
query).order_by(DagRun.execution_date)).all()
 
     @overload
     def _get_task_instances(
@@ -1671,9 +1680,9 @@ class DAG(LoggingMixin):
 
         # Do we want full objects, or just the primary columns?
         if as_pk_tuple:
-            tis = session.query(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
+            tis = select(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
         else:
-            tis = session.query(TaskInstance)
+            tis = select(TaskInstance)
         tis = tis.join(TaskInstance.dag_run)
 
         if include_subdags:
@@ -1683,40 +1692,40 @@ class DAG(LoggingMixin):
                 conditions.append(
                     (TaskInstance.dag_id == dag.dag_id) & 
TaskInstance.task_id.in_(dag.task_ids)
                 )
-            tis = tis.filter(or_(*conditions))
+            tis = tis.where(or_(*conditions))
         elif self.partial:
-            tis = tis.filter(TaskInstance.dag_id == self.dag_id, 
TaskInstance.task_id.in_(self.task_ids))
+            tis = tis.where(TaskInstance.dag_id == self.dag_id, 
TaskInstance.task_id.in_(self.task_ids))
         else:
-            tis = tis.filter(TaskInstance.dag_id == self.dag_id)
+            tis = tis.where(TaskInstance.dag_id == self.dag_id)
         if run_id:
-            tis = tis.filter(TaskInstance.run_id == run_id)
+            tis = tis.where(TaskInstance.run_id == run_id)
         if start_date:
-            tis = tis.filter(DagRun.execution_date >= start_date)
+            tis = tis.where(DagRun.execution_date >= start_date)
         if task_ids is not None:
-            tis = tis.filter(TaskInstance.ti_selector_condition(task_ids))
+            tis = tis.where(TaskInstance.ti_selector_condition(task_ids))
 
         # This allows allow_trigger_in_future config to take affect, rather 
than mandating exec_date <= UTC
         if end_date or not self.allow_future_exec_dates:
             end_date = end_date or timezone.utcnow()
-            tis = tis.filter(DagRun.execution_date <= end_date)
+            tis = tis.where(DagRun.execution_date <= end_date)
 
         if state:
             if isinstance(state, (str, TaskInstanceState)):
-                tis = tis.filter(TaskInstance.state == state)
+                tis = tis.where(TaskInstance.state == state)
             elif len(state) == 1:
-                tis = tis.filter(TaskInstance.state == state[0])
+                tis = tis.where(TaskInstance.state == state[0])
             else:
                 # this is required to deal with NULL values
                 if None in state:
                     if all(x is None for x in state):
-                        tis = tis.filter(TaskInstance.state.is_(None))
+                        tis = tis.where(TaskInstance.state.is_(None))
                     else:
                         not_none_state = [s for s in state if s]
-                        tis = tis.filter(
+                        tis = tis.where(
                             or_(TaskInstance.state.in_(not_none_state), 
TaskInstance.state.is_(None))
                         )
                 else:
-                    tis = tis.filter(TaskInstance.state.in_(state))
+                    tis = tis.where(TaskInstance.state.in_(state))
 
         # Next, get any of them from our parent DAG (if there is one)
         if include_parentdag and self.parent_dag is not None:
@@ -1754,14 +1763,17 @@ class DAG(LoggingMixin):
 
             query = tis
             if as_pk_tuple:
-                condition = TI.filter_for_tis(TaskInstanceKey(*cols) for cols 
in tis.all())
+                all_tis = session.execute(query).all()
+                condition = TI.filter_for_tis(TaskInstanceKey(*cols) for cols 
in all_tis)
                 if condition is not None:
-                    query = session.query(TI).filter(condition)
+                    query = select(TI).where(condition)
 
             if visited_external_tis is None:
                 visited_external_tis = set()
 
-            for ti in query.filter(TI.operator == ExternalTaskMarker.__name__):
+            external_tasks = session.scalars(query.where(TI.operator == 
ExternalTaskMarker.__name__))
+
+            for ti in external_tasks:
                 ti_key = ti.key.primary
                 if ti_key in visited_external_tis:
                     continue
@@ -1784,10 +1796,10 @@ class DAG(LoggingMixin):
                         f"Attempted to clear too many tasks or there may be a 
cyclic dependency."
                     )
                 ti.render_templates()
-                external_tis = (
-                    session.query(TI)
+                external_tis = session.scalars(
+                    select(TI)
                     .join(TI.dag_run)
-                    .filter(
+                    .where(
                         TI.dag_id == task.external_dag_id,
                         TI.task_id == task.external_task_id,
                         DagRun.execution_date == 
pendulum.parse(task.execution_date),
@@ -1830,9 +1842,10 @@ class DAG(LoggingMixin):
         if result or as_pk_tuple:
             # Only execute the `ti` query if we have also collected some other 
results (i.e. subdags etc.)
             if as_pk_tuple:
-                result.update(TaskInstanceKey(**cols._mapping) for cols in 
tis.all())
+                tis_query = session.execute(tis).all()
+                result.update(TaskInstanceKey(**cols._mapping) for cols in 
tis_query)
             else:
-                result.update(ti.key for ti in tis)
+                result.update(ti.key for ti in session.scalars(tis))
 
             if exclude_task_ids is not None:
                 result = {
@@ -1848,13 +1861,13 @@ class DAG(LoggingMixin):
             # We've been asked for objects, lets combine it all back in to a 
result set
             ti_filters = TI.filter_for_tis(result)
             if ti_filters is not None:
-                tis = session.query(TI).filter(ti_filters)
+                tis = select(TI).where(ti_filters)
         elif exclude_task_ids is None:
             pass  # Disable filter if not set.
         elif isinstance(next(iter(exclude_task_ids), None), str):
-            tis = tis.filter(TI.task_id.notin_(exclude_task_ids))
+            tis = tis.where(TI.task_id.notin_(exclude_task_ids))
         else:
-            tis = tis.filter(not_(tuple_in_condition((TI.task_id, 
TI.map_index), exclude_task_ids)))
+            tis = tis.where(not_(tuple_in_condition((TI.task_id, 
TI.map_index), exclude_task_ids)))
 
         return tis
 
@@ -1930,9 +1943,9 @@ class DAG(LoggingMixin):
         )
 
         if execution_date is None:
-            dag_run = (
-                session.query(DagRun).filter(DagRun.run_id == run_id, 
DagRun.dag_id == self.dag_id).one()
-            )  # Raises an error if not found
+            dag_run = session.scalars(
+                select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id == 
self.dag_id)
+            ).one()  # Raises an error if not found
             resolve_execution_date = dag_run.execution_date
         else:
             resolve_execution_date = execution_date
@@ -1993,9 +2006,9 @@ class DAG(LoggingMixin):
         locked_dag_run_ids: list[int] = []
 
         if execution_date is None:
-            dag_run = (
-                session.query(DagRun).filter(DagRun.run_id == run_id, 
DagRun.dag_id == self.dag_id).one()
-            )  # Raises an error if not found
+            dag_run = session.scalars(
+                select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id == 
self.dag_id)
+            ).one()  # Raises an error if not found
             resolve_execution_date = dag_run.execution_date
         else:
             resolve_execution_date = execution_date
@@ -2009,16 +2022,16 @@ class DAG(LoggingMixin):
             raise ValueError("TaskGroup {group_id} could not be found")
         tasks_to_set_state = [task for task in task_group.iter_tasks() if 
isinstance(task, BaseOperator)]
         task_ids = [task.task_id for task in task_group.iter_tasks()]
-        dag_runs_query = session.query(DagRun.id).filter(DagRun.dag_id == 
self.dag_id).with_for_update()
+        dag_runs_query = session.query(DagRun.id).where(DagRun.dag_id == 
self.dag_id).with_for_update()
 
         if start_date is None and end_date is None:
-            dag_runs_query = dag_runs_query.filter(DagRun.execution_date == 
start_date)
+            dag_runs_query = dag_runs_query.where(DagRun.execution_date == 
start_date)
         else:
             if start_date is not None:
-                dag_runs_query = dag_runs_query.filter(DagRun.execution_date 
>= start_date)
+                dag_runs_query = dag_runs_query.where(DagRun.execution_date >= 
start_date)
 
             if end_date is not None:
-                dag_runs_query = dag_runs_query.filter(DagRun.execution_date 
<= end_date)
+                dag_runs_query = dag_runs_query.where(DagRun.execution_date <= 
end_date)
 
         locked_dag_run_ids = dag_runs_query.all()
 
@@ -2105,12 +2118,12 @@ class DAG(LoggingMixin):
             stacklevel=3,
         )
         dag_ids = dag_ids or [self.dag_id]
-        query = session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids))
+        query = update(DagRun).where(DagRun.dag_id.in_(dag_ids))
         if start_date:
-            query = query.filter(DagRun.execution_date >= start_date)
+            query = query.where(DagRun.execution_date >= start_date)
         if end_date:
-            query = query.filter(DagRun.execution_date <= end_date)
-        query.update({DagRun.state: state}, synchronize_session="fetch")
+            query = query.where(DagRun.execution_date <= end_date)
+        
session.execute(query.values(state=state).execution_options(synchronize_session="fetch"))
 
     @provide_session
     def clear(
@@ -2196,11 +2209,11 @@ class DAG(LoggingMixin):
         )
 
         if dry_run:
-            return tis
+            return session.scalars(tis).all()
 
-        tis = list(tis)
+        tis = session.scalars(tis).all()
 
-        count = len(tis)
+        count = len(list(tis))
         do_it = True
         if count == 0:
             return 0
@@ -2213,7 +2226,7 @@ class DAG(LoggingMixin):
 
         if do_it:
             clear_task_instances(
-                tis,
+                list(tis),
                 session,
                 dag=self,
                 dag_run_state=dag_run_state,
@@ -2462,10 +2475,10 @@ class DAG(LoggingMixin):
 
     @provide_session
     def pickle(self, session=NEW_SESSION) -> DagPickle:
-        dag = session.query(DagModel).filter(DagModel.dag_id == 
self.dag_id).first()
+        dag = session.scalar(select(DagModel).where(DagModel.dag_id == 
self.dag_id).limit(1))
         dp = None
         if dag and dag.pickle_id:
-            dp = session.query(DagPickle).filter(DagPickle.id == 
dag.pickle_id).first()
+            dp = session.scalar(select(DagPickle).where(DagPickle.id == 
dag.pickle_id).limit(1))
         if not dp or dp.pickle != self:
             dp = DagPickle(dag=self)
             session.add(dp)
@@ -2881,13 +2894,14 @@ class DAG(LoggingMixin):
 
         dag_ids = set(dag_by_ids.keys())
         query = (
-            session.query(DagModel)
+            select(DagModel)
             .options(joinedload(DagModel.tags, innerjoin=False))
-            .filter(DagModel.dag_id.in_(dag_ids))
+            .where(DagModel.dag_id.in_(dag_ids))
             .options(joinedload(DagModel.schedule_dataset_references))
             .options(joinedload(DagModel.task_outlet_dataset_references))
         )
-        orm_dags: list[DagModel] = with_row_locks(query, of=DagModel, 
session=session).all()
+        query = with_row_locks(query, of=DagModel, session=session)
+        orm_dags: list[DagModel] = session.scalars(query).unique().all()
         existing_dags = {orm_dag.dag_id: orm_dag for orm_dag in orm_dags}
         missing_dag_ids = dag_ids.difference(existing_dags)
 
@@ -2903,17 +2917,19 @@ class DAG(LoggingMixin):
 
         # Get the latest dag run for each existing dag as a single query 
(avoid n+1 query)
         most_recent_subq = (
-            session.query(DagRun.dag_id, 
func.max(DagRun.execution_date).label("max_execution_date"))
-            .filter(
+            select(DagRun.dag_id, 
func.max(DagRun.execution_date).label("max_execution_date"))
+            .where(
                 DagRun.dag_id.in_(existing_dags),
                 or_(DagRun.run_type == DagRunType.BACKFILL_JOB, 
DagRun.run_type == DagRunType.SCHEDULED),
             )
             .group_by(DagRun.dag_id)
             .subquery()
         )
-        most_recent_runs_iter = session.query(DagRun).filter(
-            DagRun.dag_id == most_recent_subq.c.dag_id,
-            DagRun.execution_date == most_recent_subq.c.max_execution_date,
+        most_recent_runs_iter = session.scalars(
+            select(DagRun).where(
+                DagRun.dag_id == most_recent_subq.c.dag_id,
+                DagRun.execution_date == most_recent_subq.c.max_execution_date,
+            )
         )
         most_recent_runs = {run.dag_id: run for run in most_recent_runs_iter}
 
@@ -3029,7 +3045,9 @@ class DAG(LoggingMixin):
         # store datasets
         stored_datasets = {}
         for dataset in all_datasets:
-            stored_dataset = 
session.query(DatasetModel).filter(DatasetModel.uri == dataset.uri).first()
+            stored_dataset = session.scalar(
+                select(DatasetModel).where(DatasetModel.uri == 
dataset.uri).limit(1)
+            )
             if stored_dataset:
                 # Some datasets may have been previously unreferenced, and 
therefore orphaned by the
                 # scheduler. But if we're here, then we have found that 
dataset again in our DAGs, which
@@ -3114,7 +3132,7 @@ class DAG(LoggingMixin):
         """
         if len(active_dag_ids) == 0:
             return
-        for dag in 
session.query(DagModel).filter(~DagModel.dag_id.in_(active_dag_ids)).all():
+        for dag in 
session.scalars(select(DagModel).where(~DagModel.dag_id.in_(active_dag_ids))).all():
             dag.is_active = False
             session.merge(dag)
         session.commit()
@@ -3130,10 +3148,8 @@ class DAG(LoggingMixin):
             time
         :return: None
         """
-        for dag in (
-            session.query(DagModel)
-            .filter(DagModel.last_parsed_time < expiration_date, 
DagModel.is_active)
-            .all()
+        for dag in session.scalars(
+            select(DagModel).where(DagModel.last_parsed_time < 
expiration_date, DagModel.is_active)
         ):
             log.info(
                 "Deactivating DAG ID %s since it was last touched by the 
scheduler at %s",
@@ -3157,30 +3173,30 @@ class DAG(LoggingMixin):
         :param states: A list of states to filter by if supplied
         :return: The number of running tasks
         """
-        qry = session.query(func.count(TaskInstance.task_id)).filter(
+        qry = select(func.count(TaskInstance.task_id)).where(
             TaskInstance.dag_id == dag_id,
         )
         if run_id:
-            qry = qry.filter(
+            qry = qry.where(
                 TaskInstance.run_id == run_id,
             )
         if task_ids:
-            qry = qry.filter(
+            qry = qry.where(
                 TaskInstance.task_id.in_(task_ids),
             )
 
         if states:
             if None in states:
                 if all(x is None for x in states):
-                    qry = qry.filter(TaskInstance.state.is_(None))
+                    qry = qry.where(TaskInstance.state.is_(None))
                 else:
                     not_none_states = [state for state in states if state]
-                    qry = qry.filter(
+                    qry = qry.where(
                         or_(TaskInstance.state.in_(not_none_states), 
TaskInstance.state.is_(None))
                     )
             else:
-                qry = qry.filter(TaskInstance.state.in_(states))
-        return qry.scalar()
+                qry = qry.where(TaskInstance.state.in_(states))
+        return session.scalar(qry)
 
     @classmethod
     def get_serialized_fields(cls):
@@ -3301,7 +3317,7 @@ class DagOwnerAttributes(Base):
     @classmethod
     def get_all(cls, session) -> dict[str, dict[str, str]]:
         dag_links: dict = collections.defaultdict(dict)
-        for obj in session.query(cls):
+        for obj in session.scalars(select(cls)):
             dag_links[obj.dag_id].update({obj.owner: obj.link})
         return dag_links
 
@@ -3447,7 +3463,7 @@ class DagModel(Base):
     @classmethod
     @provide_session
     def get_current(cls, dag_id, session=NEW_SESSION):
-        return session.query(cls).filter(cls.dag_id == dag_id).first()
+        return session.scalar(select(cls).where(cls.dag_id == dag_id))
 
     @provide_session
     def get_last_dagrun(self, session=NEW_SESSION, 
include_externally_triggered=False):
@@ -3470,11 +3486,10 @@ class DagModel(Base):
         :param session: ORM Session
         :return: Paused Dag_ids
         """
-        paused_dag_ids = (
-            session.query(DagModel.dag_id)
-            .filter(DagModel.is_paused == expression.true())
-            .filter(DagModel.dag_id.in_(dag_ids))
-            .all()
+        paused_dag_ids = session.execute(
+            select(DagModel.dag_id)
+            .where(DagModel.is_paused == expression.true())
+            .where(DagModel.dag_id.in_(dag_ids))
         )
 
         paused_dag_ids = {paused_dag_id for paused_dag_id, in paused_dag_ids}
@@ -3518,8 +3533,11 @@ class DagModel(Base):
         ]
         if including_subdags:
             filter_query.append(DagModel.root_dag_id == self.dag_id)
-        session.query(DagModel).filter(or_(*filter_query)).update(
-            {DagModel.is_paused: is_paused}, synchronize_session="fetch"
+        session.execute(
+            update(DagModel)
+            .where(or_(*filter_query))
+            .values(is_paused=is_paused)
+            .execution_options(synchronize_session="fetch")
         )
         session.commit()
 
@@ -3538,7 +3556,8 @@ class DagModel(Base):
         :param session: ORM Session
         """
         log.debug("Deactivating DAGs (for which DAG files are deleted) from %s 
table ", cls.__tablename__)
-        for dag_model in session.query(cls).filter(cls.fileloc.is_not(None)):
+        dag_models = 
session.scalars(select(cls).where(cls.fileloc.is_not(None)))
+        for dag_model in dag_models:
             if dag_model.fileloc not in alive_dag_filelocs:
                 dag_model.is_active = False
 
@@ -3556,28 +3575,30 @@ class DagModel(Base):
         # these dag ids are triggered by datasets, and they are ready to go.
         dataset_triggered_dag_info = {
             x.dag_id: (x.first_queued_time, x.last_queued_time)
-            for x in session.query(
-                DagScheduleDatasetReference.dag_id,
-                func.max(DDRQ.created_at).label("last_queued_time"),
-                func.min(DDRQ.created_at).label("first_queued_time"),
+            for x in session.execute(
+                select(
+                    DagScheduleDatasetReference.dag_id,
+                    func.max(DDRQ.created_at).label("last_queued_time"),
+                    func.min(DDRQ.created_at).label("first_queued_time"),
+                )
+                .join(DagScheduleDatasetReference.queue_records, isouter=True)
+                .group_by(DagScheduleDatasetReference.dag_id)
+                .having(func.count() == 
func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
             )
-            .join(DagScheduleDatasetReference.queue_records, isouter=True)
-            .group_by(DagScheduleDatasetReference.dag_id)
-            .having(func.count() == 
func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
-            .all()
         }
         dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
         if dataset_triggered_dag_ids:
             exclusion_list = {
-                x.dag_id
+                x
                 for x in (
-                    session.query(DagModel.dag_id)
-                    .join(DagRun.dag_model)
-                    .filter(DagRun.state.in_((DagRunState.QUEUED, 
DagRunState.RUNNING)))
-                    .filter(DagModel.dag_id.in_(dataset_triggered_dag_ids))
-                    .group_by(DagModel.dag_id)
-                    .having(func.count() >= func.max(DagModel.max_active_runs))
-                    .all()
+                    session.scalars(
+                        select(DagModel.dag_id)
+                        .join(DagRun.dag_model)
+                        .where(DagRun.state.in_((DagRunState.QUEUED, 
DagRunState.RUNNING)))
+                        .where(DagModel.dag_id.in_(dataset_triggered_dag_ids))
+                        .group_by(DagModel.dag_id)
+                        .having(func.count() >= 
func.max(DagModel.max_active_runs))
+                    )
                 )
             }
             if exclusion_list:
@@ -3588,8 +3609,8 @@ class DagModel(Base):
 
         # We limit so that _one_ scheduler doesn't try to do all the creation 
of dag runs
         query = (
-            session.query(cls)
-            .filter(
+            select(cls)
+            .where(
                 cls.is_paused == expression.false(),
                 cls.is_active == expression.true(),
                 cls.has_import_errors == expression.false(),
@@ -3603,7 +3624,7 @@ class DagModel(Base):
         )
 
         return (
-            with_row_locks(query, of=cls, session=session, 
**skip_locked(session=session)),
+            session.scalars(with_row_locks(query, of=cls, session=session, 
**skip_locked(session=session))),
             dataset_triggered_dag_info,
         )
 
@@ -3871,10 +3892,8 @@ def _get_or_create_dagrun(
     :return: The newly created DAG run.
     """
     log.info("dagrun id: %s", dag.dag_id)
-    dr: DagRun = (
-        session.query(DagRun)
-        .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == 
execution_date)
-        .first()
+    dr: DagRun = session.scalar(
+        select(DagRun).where(DagRun.dag_id == dag.dag_id, 
DagRun.execution_date == execution_date)
     )
     if dr:
         session.delete(dr)
diff --git a/airflow/models/dagcode.py b/airflow/models/dagcode.py
index 9849f328f2..61b007bf6c 100644
--- a/airflow/models/dagcode.py
+++ b/airflow/models/dagcode.py
@@ -22,7 +22,7 @@ import struct
 from datetime import datetime
 from typing import Collection, Iterable
 
-from sqlalchemy import BigInteger, Column, String, Text, delete
+from sqlalchemy import BigInteger, Column, String, Text, delete, select
 from sqlalchemy.dialects.mysql import MEDIUMTEXT
 from sqlalchemy.orm import Session
 from sqlalchemy.sql.expression import literal
@@ -77,12 +77,11 @@ class DagCode(Base):
         """
         filelocs = set(filelocs)
         filelocs_to_hashes = {fileloc: DagCode.dag_fileloc_hash(fileloc) for 
fileloc in filelocs}
-        existing_orm_dag_codes = (
-            session.query(DagCode)
+        existing_orm_dag_codes = session.scalars(
+            select(DagCode)
             .filter(DagCode.fileloc_hash.in_(filelocs_to_hashes.values()))
             .with_for_update(of=DagCode)
-            .all()
-        )
+        ).all()
 
         if existing_orm_dag_codes:
             existing_orm_dag_codes_map = {
@@ -151,7 +150,10 @@ class DagCode(Base):
         :param session: ORM Session
         """
         fileloc_hash = cls.dag_fileloc_hash(fileloc)
-        return session.query(literal(True)).filter(cls.fileloc_hash == 
fileloc_hash).one_or_none() is not None
+        return (
+            session.scalars(select(literal(True)).where(cls.fileloc_hash == 
fileloc_hash)).one_or_none()
+            is not None
+        )
 
     @classmethod
     def get_code_by_fileloc(cls, fileloc: str) -> str:
@@ -179,7 +181,7 @@ class DagCode(Base):
     @classmethod
     @provide_session
     def _get_code_from_db(cls, fileloc, session: Session = NEW_SESSION) -> str:
-        dag_code = session.query(cls).filter(cls.fileloc_hash == 
cls.dag_fileloc_hash(fileloc)).first()
+        dag_code = session.scalar(select(cls).where(cls.fileloc_hash == 
cls.dag_fileloc_hash(fileloc)))
         if not dag_code:
             raise DagCodeNotFound()
         else:
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 538cceede8..dab89c12cd 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -40,6 +40,7 @@ from sqlalchemy import (
     func,
     or_,
     text,
+    update,
 )
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.associationproxy import association_proxy
@@ -270,7 +271,9 @@ class DagRun(Base, LoggingMixin):
 
         :param session: database session
         """
-        dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, 
DagRun.run_id == self.run_id).one()
+        dr = session.scalars(
+            select(DagRun).where(DagRun.dag_id == self.dag_id, DagRun.run_id 
== self.run_id)
+        ).one()
         self.id = dr.id
         self.state = dr.state
 
@@ -283,17 +286,17 @@ class DagRun(Base, LoggingMixin):
         session: Session = NEW_SESSION,
     ) -> dict[str, int]:
         """Get the number of active dag runs for each dag."""
-        query = session.query(cls.dag_id, func.count("*"))
+        query = select(cls.dag_id, func.count("*"))
         if dag_ids is not None:
             # 'set' called to avoid duplicate dag_ids, but converted back to 
'list'
             # because SQLAlchemy doesn't accept a set here.
-            query = query.filter(cls.dag_id.in_(set(dag_ids)))
+            query = query.where(cls.dag_id.in_(set(dag_ids)))
         if only_running:
-            query = query.filter(cls.state == State.RUNNING)
+            query = query.where(cls.state == State.RUNNING)
         else:
-            query = query.filter(cls.state.in_([State.RUNNING, State.QUEUED]))
+            query = query.where(cls.state.in_([State.RUNNING, State.QUEUED]))
         query = query.group_by(cls.dag_id)
-        return {dag_id: count for dag_id, count in query.all()}
+        return {dag_id: count for dag_id, count in session.execute(query)}
 
     @classmethod
     def next_dagruns_to_examine(
@@ -317,22 +320,22 @@ class DagRun(Base, LoggingMixin):
 
         # TODO: Bake this query, it is run _A lot_
         query = (
-            session.query(cls)
+            select(cls)
             .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", 
dialect_name="mysql")
-            .filter(cls.state == state, cls.run_type != 
DagRunType.BACKFILL_JOB)
+            .where(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
             .join(DagModel, DagModel.dag_id == cls.dag_id)
-            .filter(DagModel.is_paused == false(), DagModel.is_active == 
true())
+            .where(DagModel.is_paused == false(), DagModel.is_active == true())
         )
         if state == State.QUEUED:
             # For dag runs in the queued state, we check if they have reached 
the max_active_runs limit
             # and if so we drop them
             running_drs = (
-                session.query(DagRun.dag_id, 
func.count(DagRun.state).label("num_running"))
-                .filter(DagRun.state == DagRunState.RUNNING)
+                select(DagRun.dag_id, 
func.count(DagRun.state).label("num_running"))
+                .where(DagRun.state == DagRunState.RUNNING)
                 .group_by(DagRun.dag_id)
                 .subquery()
             )
-            query = query.outerjoin(running_drs, running_drs.c.dag_id == 
DagRun.dag_id).filter(
+            query = query.outerjoin(running_drs, running_drs.c.dag_id == 
DagRun.dag_id).where(
                 func.coalesce(running_drs.c.num_running, 0) < 
DagModel.max_active_runs
             )
         query = query.order_by(
@@ -341,10 +344,10 @@ class DagRun(Base, LoggingMixin):
         )
 
         if not settings.ALLOW_FUTURE_EXEC_DATES:
-            query = query.filter(DagRun.execution_date <= func.now())
+            query = query.where(DagRun.execution_date <= func.now())
 
-        return with_row_locks(
-            query.limit(max_number), of=cls, session=session, 
**skip_locked(session=session)
+        return session.scalars(
+            with_row_locks(query.limit(max_number), of=cls, session=session, 
**skip_locked(session=session))
         )
 
     @classmethod
@@ -377,35 +380,35 @@ class DagRun(Base, LoggingMixin):
         :param execution_start_date: dag run that was executed from this date
         :param execution_end_date: dag run that was executed until this date
         """
-        qry = session.query(cls)
+        qry = select(cls)
         dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
         if dag_ids:
-            qry = qry.filter(cls.dag_id.in_(dag_ids))
+            qry = qry.where(cls.dag_id.in_(dag_ids))
 
         if is_container(run_id):
-            qry = qry.filter(cls.run_id.in_(run_id))
+            qry = qry.where(cls.run_id.in_(run_id))
         elif run_id is not None:
-            qry = qry.filter(cls.run_id == run_id)
+            qry = qry.where(cls.run_id == run_id)
         if is_container(execution_date):
-            qry = qry.filter(cls.execution_date.in_(execution_date))
+            qry = qry.where(cls.execution_date.in_(execution_date))
         elif execution_date is not None:
-            qry = qry.filter(cls.execution_date == execution_date)
+            qry = qry.where(cls.execution_date == execution_date)
         if execution_start_date and execution_end_date:
-            qry = qry.filter(cls.execution_date.between(execution_start_date, 
execution_end_date))
+            qry = qry.where(cls.execution_date.between(execution_start_date, 
execution_end_date))
         elif execution_start_date:
-            qry = qry.filter(cls.execution_date >= execution_start_date)
+            qry = qry.where(cls.execution_date >= execution_start_date)
         elif execution_end_date:
-            qry = qry.filter(cls.execution_date <= execution_end_date)
+            qry = qry.where(cls.execution_date <= execution_end_date)
         if state:
-            qry = qry.filter(cls.state == state)
+            qry = qry.where(cls.state == state)
         if external_trigger is not None:
-            qry = qry.filter(cls.external_trigger == external_trigger)
+            qry = qry.where(cls.external_trigger == external_trigger)
         if run_type:
-            qry = qry.filter(cls.run_type == run_type)
+            qry = qry.where(cls.run_type == run_type)
         if no_backfills:
-            qry = qry.filter(cls.run_type != DagRunType.BACKFILL_JOB)
+            qry = qry.where(cls.run_type != DagRunType.BACKFILL_JOB)
 
-        return qry.order_by(cls.execution_date).all()
+        return session.scalars(qry.order_by(cls.execution_date)).all()
 
     @classmethod
     @provide_session
@@ -426,14 +429,12 @@ class DagRun(Base, LoggingMixin):
         :param execution_date: the execution date
         :param session: database session
         """
-        return (
-            session.query(cls)
-            .filter(
+        return session.scalars(
+            select(cls).where(
                 cls.dag_id == dag_id,
                 or_(cls.run_id == run_id, cls.execution_date == 
execution_date),
             )
-            .one_or_none()
-        )
+        ).one_or_none()
 
     @staticmethod
     def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
@@ -449,9 +450,9 @@ class DagRun(Base, LoggingMixin):
     ) -> list[TI]:
         """Returns the task instances for this dag run."""
         tis = (
-            session.query(TI)
+            select(TI)
             .options(joinedload(TI.dag_run))
-            .filter(
+            .where(
                 TI.dag_id == self.dag_id,
                 TI.run_id == self.run_id,
             )
@@ -459,21 +460,21 @@ class DagRun(Base, LoggingMixin):
 
         if state:
             if isinstance(state, str):
-                tis = tis.filter(TI.state == state)
+                tis = tis.where(TI.state == state)
             else:
                 # this is required to deal with NULL values
                 if State.NONE in state:
                     if all(x is None for x in state):
-                        tis = tis.filter(TI.state.is_(None))
+                        tis = tis.where(TI.state.is_(None))
                     else:
                         not_none_state = [s for s in state if s]
-                        tis = tis.filter(or_(TI.state.in_(not_none_state), 
TI.state.is_(None)))
+                        tis = tis.where(or_(TI.state.in_(not_none_state), 
TI.state.is_(None)))
                 else:
-                    tis = tis.filter(TI.state.in_(state))
+                    tis = tis.where(TI.state.in_(state))
 
         if self.dag and self.dag.partial:
-            tis = tis.filter(TI.task_id.in_(self.dag.task_ids))
-        return tis.all()
+            tis = tis.where(TI.task_id.in_(self.dag.task_ids))
+        return session.scalars(tis).all()
 
     @provide_session
     def get_task_instance(
@@ -489,11 +490,9 @@ class DagRun(Base, LoggingMixin):
         :param task_id: the task id
         :param session: Sqlalchemy ORM Session
         """
-        return (
-            session.query(TI)
-            .filter_by(dag_id=self.dag_id, run_id=self.run_id, 
task_id=task_id, map_index=map_index)
-            .one_or_none()
-        )
+        return session.scalars(
+            select(TI).filter_by(dag_id=self.dag_id, run_id=self.run_id, 
task_id=task_id, map_index=map_index)
+        ).one_or_none()
 
     def get_dag(self) -> DAG:
         """
@@ -517,20 +516,19 @@ class DagRun(Base, LoggingMixin):
         ]
         if state is not None:
             filters.append(DagRun.state == state)
-        return 
session.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first()
+        return 
session.scalar(select(DagRun).where(*filters).order_by(DagRun.execution_date.desc()))
 
     @provide_session
     def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> 
DagRun | None:
         """The previous, SCHEDULED DagRun, if there is one."""
-        return (
-            session.query(DagRun)
-            .filter(
+        return session.scalar(
+            select(DagRun)
+            .where(
                 DagRun.dag_id == self.dag_id,
                 DagRun.execution_date < self.execution_date,
                 DagRun.run_type != DagRunType.MANUAL,
             )
             .order_by(DagRun.execution_date.desc())
-            .first()
         )
 
     def _tis_for_dagrun_state(self, *, dag, tis):
@@ -867,7 +865,7 @@ class DagRun(Base, LoggingMixin):
         # Check if any ti changed state
         tis_filter = TI.filter_for_tis(old_states)
         if tis_filter is not None:
-            fresh_tis = session.query(TI).filter(tis_filter).all()
+            fresh_tis = session.scalars(select(TI).where(tis_filter)).all()
             changed_tis = any(ti.state != old_states[ti.key] for ti in 
fresh_tis)
 
         return ready_tis, changed_tis, expansion_happened
@@ -1219,21 +1217,27 @@ class DagRun(Base, LoggingMixin):
         except NotFullyPopulated:
             return  # Upstreams not ready, don't need to revise this yet.
 
-        query = session.query(TI.map_index).filter(
-            TI.dag_id == self.dag_id,
-            TI.task_id == task.task_id,
-            TI.run_id == self.run_id,
+        query = session.scalars(
+            select(TI.map_index).where(
+                TI.dag_id == self.dag_id,
+                TI.task_id == task.task_id,
+                TI.run_id == self.run_id,
+            )
         )
-        existing_indexes = {i for (i,) in query}
+        existing_indexes = {i for i in query}
 
         removed_indexes = existing_indexes.difference(range(total_length))
         if removed_indexes:
-            session.query(TI).filter(
-                TI.dag_id == self.dag_id,
-                TI.task_id == task.task_id,
-                TI.run_id == self.run_id,
-                TI.map_index.in_(removed_indexes),
-            ).update({TI.state: TaskInstanceState.REMOVED})
+            session.execute(
+                update(TI)
+                .where(
+                    TI.dag_id == self.dag_id,
+                    TI.task_id == task.task_id,
+                    TI.run_id == self.run_id,
+                    TI.map_index.in_(removed_indexes),
+                )
+                .values(state=TaskInstanceState.REMOVED)
+            )
             session.flush()
 
         for index in range(total_length):
@@ -1264,14 +1268,12 @@ class DagRun(Base, LoggingMixin):
             RemovedInAirflow3Warning,
             stacklevel=2,
         )
-        return (
-            session.query(DagRun)
-            .filter(
+        return session.scalar(
+            select(DagRun).where(
                 DagRun.dag_id == dag_id,
                 DagRun.external_trigger == False,  # noqa
                 DagRun.execution_date == execution_date,
             )
-            .first()
         )
 
     @property
@@ -1283,18 +1285,16 @@ class DagRun(Base, LoggingMixin):
     def get_latest_runs(cls, session: Session = NEW_SESSION) -> list[DagRun]:
         """Returns the latest DagRun for each DAG."""
         subquery = (
-            session.query(cls.dag_id, 
func.max(cls.execution_date).label("execution_date"))
+            select(cls.dag_id, 
func.max(cls.execution_date).label("execution_date"))
             .group_by(cls.dag_id)
             .subquery()
         )
-        return (
-            session.query(cls)
-            .join(
+        return session.scalars(
+            select(cls).join(
                 subquery,
                 and_(cls.dag_id == subquery.c.dag_id, cls.execution_date == 
subquery.c.execution_date),
             )
-            .all()
-        )
+        ).all()
 
     @provide_session
     def schedule_tis(
@@ -1335,44 +1335,45 @@ class DagRun(Base, LoggingMixin):
                 schedulable_ti_ids, max_tis_per_query or 
len(schedulable_ti_ids)
             )
             for schedulable_ti_ids_chunk in schedulable_ti_ids_chunks:
-                count += (
-                    session.query(TI)
-                    .filter(
+                count += session.execute(
+                    update(TI)
+                    .where(
                         TI.dag_id == self.dag_id,
                         TI.run_id == self.run_id,
                         tuple_in_condition((TI.task_id, TI.map_index), 
schedulable_ti_ids_chunk),
                     )
-                    .update({TI.state: State.SCHEDULED}, 
synchronize_session=False)
-                )
+                    .values(state=TaskInstanceState.SCHEDULED)
+                    .execution_options(synchronize_session=False)
+                ).rowcount
 
         # Tasks using EmptyOperator should not be executed, mark them as 
success
         if dummy_ti_ids:
             dummy_ti_ids_chunks = chunks(dummy_ti_ids, max_tis_per_query or 
len(dummy_ti_ids))
             for dummy_ti_ids_chunk in dummy_ti_ids_chunks:
-                count += (
-                    session.query(TI)
-                    .filter(
+                count += session.execute(
+                    update(TI)
+                    .where(
                         TI.dag_id == self.dag_id,
                         TI.run_id == self.run_id,
                         TI.task_id.in_(dummy_ti_ids_chunk),
                     )
-                    .update(
-                        {
-                            TI.state: State.SUCCESS,
-                            TI.start_date: timezone.utcnow(),
-                            TI.end_date: timezone.utcnow(),
-                            TI.duration: 0,
-                        },
+                    .values(
+                        state=TaskInstanceState.SUCCESS,
+                        start_date=timezone.utcnow(),
+                        end_date=timezone.utcnow(),
+                        duration=0,
+                    )
+                    .execution_options(
                         synchronize_session=False,
                     )
-                )
+                ).rowcount
 
         return count
 
     @provide_session
     def get_log_template(self, *, session: Session = NEW_SESSION) -> 
LogTemplate:
         if self.log_template_id is None:  # DagRun created before LogTemplate 
introduction.
-            template = 
session.query(LogTemplate).order_by(LogTemplate.id).first()
+            template = 
session.scalar(select(LogTemplate).order_by(LogTemplate.id))
         else:
             template = session.get(LogTemplate, self.log_template_id)
         if template is None:
diff --git a/airflow/models/dagwarning.py b/airflow/models/dagwarning.py
index ddc5441321..2690ba8ff8 100644
--- a/airflow/models/dagwarning.py
+++ b/airflow/models/dagwarning.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from enum import Enum
 
-from sqlalchemy import Column, ForeignKeyConstraint, String, Text, false
+from sqlalchemy import Column, ForeignKeyConstraint, String, Text, delete, 
false, select
 from sqlalchemy.orm import Session
 
 from airflow.api_internal.internal_api_call import internal_api_call
@@ -83,11 +83,12 @@ class DagWarning(Base):
         from airflow.models.dag import DagModel
 
         if session.get_bind().dialect.name == "sqlite":
-            dag_ids = session.query(DagModel.dag_id).filter(DagModel.is_active 
== false())
-            query = session.query(cls).filter(cls.dag_id.in_(dag_ids))
+            dag_ids_stmt = select(DagModel.dag_id).where(DagModel.is_active == 
false())
+            query = 
delete(cls).where(cls.dag_id.in_(dag_ids_stmt.scalar_subquery()))
         else:
-            query = session.query(cls).filter(cls.dag_id == DagModel.dag_id, 
DagModel.is_active == false())
-        query.delete(synchronize_session=False)
+            query = delete(cls).where(cls.dag_id == DagModel.dag_id, 
DagModel.is_active == false())
+
+        session.execute(query.execution_options(synchronize_session=False))
         session.commit()
 
 
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index d1766d4a0a..60f92506f6 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -17,9 +17,9 @@
 # under the License.
 from __future__ import annotations
 
-from typing import Any, Iterable
+from typing import Any
 
-from sqlalchemy import Column, Integer, String, Text, func
+from sqlalchemy import Column, Integer, String, Text, func, select
 from sqlalchemy.orm.session import Session
 
 from airflow.exceptions import AirflowException, PoolNotFound
@@ -60,7 +60,7 @@ class Pool(Base):
     @provide_session
     def get_pools(session: Session = NEW_SESSION) -> list[Pool]:
         """Get all pools."""
-        return session.query(Pool).all()
+        return session.scalars(select(Pool)).all()
 
     @staticmethod
     @provide_session
@@ -72,7 +72,7 @@ class Pool(Base):
         :param session: SQLAlchemy ORM Session
         :return: the pool object
         """
-        return session.query(Pool).filter(Pool.pool == pool_name).first()
+        return session.scalar(select(Pool).where(Pool.pool == pool_name))
 
     @staticmethod
     @provide_session
@@ -96,9 +96,9 @@ class Pool(Base):
         :return: True if id is default_pool, otherwise False
         """
         return (
-            session.query(func.count(Pool.id))
-            .filter(Pool.id == id, Pool.pool == Pool.DEFAULT_POOL_NAME)
-            .scalar()
+            session.scalar(
+                select(func.count(Pool.id)).where(Pool.id == id, Pool.pool == 
Pool.DEFAULT_POOL_NAME)
+            )
             > 0
         )
 
@@ -114,7 +114,7 @@ class Pool(Base):
         if not name:
             raise ValueError("Pool name must not be empty")
 
-        pool = session.query(Pool).filter_by(pool=name).one_or_none()
+        pool = session.scalar(select(Pool).filter_by(pool=name))
         if pool is None:
             pool = Pool(pool=name, slots=slots, description=description)
             session.add(pool)
@@ -132,7 +132,7 @@ class Pool(Base):
         if name == Pool.DEFAULT_POOL_NAME:
             raise AirflowException(f"{Pool.DEFAULT_POOL_NAME} cannot be 
deleted")
 
-        pool = session.query(Pool).filter_by(pool=name).first()
+        pool = session.scalar(select(Pool).filter_by(pool=name))
         if pool is None:
             raise PoolNotFound(f"Pool '{name}' doesn't exist")
 
@@ -162,22 +162,22 @@ class Pool(Base):
 
         pools: dict[str, PoolStats] = {}
 
-        query = session.query(Pool.pool, Pool.slots)
+        query = select(Pool.pool, Pool.slots)
 
         if lock_rows:
             query = with_row_locks(query, session=session, **nowait(session))
 
-        pool_rows: Iterable[tuple[str, int]] = query.all()
+        pool_rows = session.execute(query)
         for (pool_name, total_slots) in pool_rows:
             if total_slots == -1:
                 total_slots = float("inf")  # type: ignore
             pools[pool_name] = PoolStats(total=total_slots, running=0, 
queued=0, open=0)
 
-        state_count_by_pool = (
-            session.query(TaskInstance.pool, TaskInstance.state, 
func.sum(TaskInstance.pool_slots))
+        state_count_by_pool = session.execute(
+            select(TaskInstance.pool, TaskInstance.state, 
func.sum(TaskInstance.pool_slots))
             .filter(TaskInstance.state.in_(list(EXECUTION_STATES)))
             .group_by(TaskInstance.pool, TaskInstance.state)
-        ).all()
+        )
 
         # calculate queued and running metrics
         for (pool_name, state, count) in state_count_by_pool:
@@ -225,10 +225,11 @@ class Pool(Base):
         from airflow.models.taskinstance import TaskInstance  # Avoid circular 
import
 
         return int(
-            session.query(func.sum(TaskInstance.pool_slots))
-            .filter(TaskInstance.pool == self.pool)
-            .filter(TaskInstance.state.in_(EXECUTION_STATES))
-            .scalar()
+            session.scalar(
+                select(func.sum(TaskInstance.pool_slots))
+                .filter(TaskInstance.pool == self.pool)
+                .filter(TaskInstance.state.in_(EXECUTION_STATES))
+            )
             or 0
         )
 
@@ -243,10 +244,11 @@ class Pool(Base):
         from airflow.models.taskinstance import TaskInstance  # Avoid circular 
import
 
         return int(
-            session.query(func.sum(TaskInstance.pool_slots))
-            .filter(TaskInstance.pool == self.pool)
-            .filter(TaskInstance.state == State.RUNNING)
-            .scalar()
+            session.scalar(
+                select(func.sum(TaskInstance.pool_slots))
+                .filter(TaskInstance.pool == self.pool)
+                .filter(TaskInstance.state == State.RUNNING)
+            )
             or 0
         )
 
@@ -261,10 +263,11 @@ class Pool(Base):
         from airflow.models.taskinstance import TaskInstance  # Avoid circular 
import
 
         return int(
-            session.query(func.sum(TaskInstance.pool_slots))
-            .filter(TaskInstance.pool == self.pool)
-            .filter(TaskInstance.state == State.QUEUED)
-            .scalar()
+            session.scalar(
+                select(func.sum(TaskInstance.pool_slots))
+                .filter(TaskInstance.pool == self.pool)
+                .filter(TaskInstance.state == State.QUEUED)
+            )
             or 0
         )
 
@@ -279,10 +282,11 @@ class Pool(Base):
         from airflow.models.taskinstance import TaskInstance  # Avoid circular 
import
 
         return int(
-            session.query(func.sum(TaskInstance.pool_slots))
-            .filter(TaskInstance.pool == self.pool)
-            .filter(TaskInstance.state == State.SCHEDULED)
-            .scalar()
+            session.scalar(
+                select(func.sum(TaskInstance.pool_slots))
+                .filter(TaskInstance.pool == self.pool)
+                .filter(TaskInstance.state == State.SCHEDULED)
+            )
             or 0
         )
 
diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py
index 4903a88e56..edba1ffd9c 100644
--- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py
@@ -385,6 +385,7 @@ class TestGetMappedTaskInstances(TestMappedTaskInstanceEndpoint):
         )
         assert response.status_code == 200
         assert response.json["total_entries"] == 0
+        assert response.json["task_instances"] == []
 
     @provide_session
     def test_mapped_task_instances_with_state(self, one_task_with_mapped_tis, 
session):
@@ -402,6 +403,7 @@ class TestGetMappedTaskInstances(TestMappedTaskInstanceEndpoint):
         )
         assert response.status_code == 200
         assert response.json["total_entries"] == 0
+        assert response.json["task_instances"] == []
 
     @provide_session
     def test_mapped_task_instances_with_pool(self, one_task_with_mapped_tis, 
session):
@@ -420,6 +422,7 @@ class TestGetMappedTaskInstances(TestMappedTaskInstanceEndpoint):
         )
         assert response.status_code == 200
         assert response.json["total_entries"] == 0
+        assert response.json["task_instances"] == []
 
     @provide_session
     def test_mapped_task_instances_with_queue(self, one_task_with_mapped_tis, 
session):
@@ -437,6 +440,7 @@ class TestGetMappedTaskInstances(TestMappedTaskInstanceEndpoint):
         )
         assert response.status_code == 200
         assert response.json["total_entries"] == 0
+        assert response.json["task_instances"] == []
 
     @provide_session
     def test_mapped_task_instances_with_zero_mapped(self, 
one_task_with_zero_mapped_tis, session):
@@ -446,4 +450,4 @@ class TestGetMappedTaskInstances(TestMappedTaskInstanceEndpoint):
         )
         assert response.status_code == 200
         assert response.json["total_entries"] == 0
-        assert len(response.json["task_instances"]) == 0
+        assert response.json["task_instances"] == []
diff --git a/tests/models/test_dagwarning.py b/tests/models/test_dagwarning.py
index 7ae2962b1f..06b14b56ea 100644
--- a/tests/models/test_dagwarning.py
+++ b/tests/models/test_dagwarning.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+from unittest import mock
 from unittest.mock import MagicMock
 
 from sqlalchemy.exc import OperationalError
@@ -52,18 +53,18 @@ class TestDagWarning:
         assert len(remaining_dag_warnings) == 1
         assert remaining_dag_warnings[0].dag_id == "dag_2"
 
-    def test_retry_purge_inactive_dag_warnings(self):
+    @mock.patch("airflow.models.dagwarning.delete")
+    def test_retry_purge_inactive_dag_warnings(self, delete_mock):
         """
         Test that the purge_inactive_dag_warnings method calls the delete 
method twice
         if the query throws an operationalError on the first call and works on 
the second attempt
         """
         self.session_mock = MagicMock()
-        self.delete_mock = MagicMock()
-        self.session_mock.query.return_value.filter.return_value.delete = 
self.delete_mock
 
-        self.delete_mock.side_effect = [OperationalError(None, None, "database 
timeout"), None]
+        self.session_mock.execute.side_effect = [OperationalError(None, None, 
"database timeout"), None]
 
         DagWarning.purge_inactive_dag_warnings(self.session_mock)
 
         # Assert that the delete method was called twice
-        assert self.delete_mock.call_count == 2
+        assert delete_mock.call_count == 2
+        assert self.session_mock.execute.call_count == 2

Reply via email to