This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4c4981d1ad Refactor Sqlalchemy queries to 2.0 style (Part 7) (#32883)
4c4981d1ad is described below

commit 4c4981d1adf2bd8b28ffa7e6ed57162abb8feb8f
Author: Phani Kumar <[email protected]>
AuthorDate: Mon Aug 21 19:04:38 2023 +0530

    Refactor Sqlalchemy queries to 2.0 style (Part 7) (#32883)
---
 airflow/api/common/experimental/pool.py            |  2 +-
 airflow/auth/managers/fab/models/__init__.py       | 14 +++--
 airflow/dag_processing/processor.py                | 45 ++++++--------
 airflow/datasets/manager.py                        |  4 +-
 airflow/models/dag.py                              |  3 +-
 airflow/models/skipmixin.py                        | 51 ++++++++--------
 airflow/models/variable.py                         |  5 +-
 airflow/operators/subdag.py                        | 13 +++--
 .../celery/executors/celery_executor_utils.py      |  3 +-
 .../kubernetes/executors/kubernetes_executor.py    | 25 ++++----
 airflow/utils/scheduler_health.py                  |  8 ++-
 airflow/utils/sqlalchemy.py                        |  2 +-
 airflow/www/security.py                            | 68 +++++++++++-----------
 tests/datasets/test_manager.py                     |  2 +-
 .../integration/executors/test_celery_executor.py  |  4 +-
 15 files changed, 126 insertions(+), 123 deletions(-)

diff --git a/airflow/api/common/experimental/pool.py b/airflow/api/common/experimental/pool.py
index 34e35cd435..1134bda989 100644
--- a/airflow/api/common/experimental/pool.py
+++ b/airflow/api/common/experimental/pool.py
@@ -45,7 +45,7 @@ def get_pool(name, session: Session = NEW_SESSION):
 @provide_session
 def get_pools(session: Session = NEW_SESSION):
     """Get all pools."""
-    return session.query(Pool).all()
+    return session.scalars(select(Pool)).all()
 
 
 @deprecated(reason="Use Pool.create_pool() instead", version="2.2.4")
diff --git a/airflow/auth/managers/fab/models/__init__.py b/airflow/auth/managers/fab/models/__init__.py
index 0bc26adb7e..28de7840ec 100644
--- a/airflow/auth/managers/fab/models/__init__.py
+++ b/airflow/auth/managers/fab/models/__init__.py
@@ -37,6 +37,7 @@ from sqlalchemy import (
     UniqueConstraint,
     event,
     func,
+    select,
 )
 from sqlalchemy.orm import backref, declared_attr, relationship
 
@@ -210,12 +211,13 @@ class User(Model, BaseUser):
             if current_app:
                 sm = current_app.appbuilder.sm
                 self._perms: set[tuple[str, str]] = set(
-                    sm.get_session.query(sm.action_model.name, 
sm.resource_model.name)
-                    .join(sm.permission_model.action)
-                    .join(sm.permission_model.resource)
-                    .join(sm.permission_model.role)
-                    .filter(sm.role_model.user.contains(self))
-                    .all()
+                    sm.get_session.execute(
+                        select(sm.action_model.name, sm.resource_model.name)
+                        .join(sm.permission_model.action)
+                        .join(sm.permission_model.resource)
+                        .join(sm.permission_model.role)
+                        .where(sm.role_model.user.contains(self))
+                    )
                 )
             else:
                 self._perms = {
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index 162fc5889c..64858be94e 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -30,7 +30,7 @@ from multiprocessing.connection import Connection as 
MultiprocessingConnection
 from typing import TYPE_CHECKING, Iterable, Iterator
 
 from setproctitle import setproctitle
-from sqlalchemy import delete, exc, func, or_
+from sqlalchemy import delete, exc, func, or_, select
 from sqlalchemy.orm.session import Session
 
 from airflow import settings
@@ -428,31 +428,27 @@ class DagFileProcessor(LoggingMixin):
         if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks):
             cls.logger().info("Skipping SLA check for %s because no tasks in 
DAG have SLAs", dag)
             return
-
         qry = (
-            session.query(TI.task_id, 
func.max(DR.execution_date).label("max_ti"))
+            select(TI.task_id, func.max(DR.execution_date).label("max_ti"))
             .join(TI.dag_run)
-            .filter(TI.dag_id == dag.dag_id)
-            .filter(or_(TI.state == TaskInstanceState.SUCCESS, TI.state == 
TaskInstanceState.SKIPPED))
-            .filter(TI.task_id.in_(dag.task_ids))
+            .where(TI.dag_id == dag.dag_id)
+            .where(or_(TI.state == TaskInstanceState.SUCCESS, TI.state == 
TaskInstanceState.SKIPPED))
+            .where(TI.task_id.in_(dag.task_ids))
             .group_by(TI.task_id)
             .subquery("sq")
         )
         # get recorded SlaMiss
         recorded_slas_query = set(
-            session.query(SlaMiss.dag_id, SlaMiss.task_id, 
SlaMiss.execution_date).filter(
-                SlaMiss.dag_id == dag.dag_id, SlaMiss.task_id.in_(dag.task_ids)
+            session.execute(
+                select(SlaMiss.dag_id, SlaMiss.task_id, 
SlaMiss.execution_date).where(
+                    SlaMiss.dag_id == dag.dag_id, 
SlaMiss.task_id.in_(dag.task_ids)
+                )
             )
         )
-
-        max_tis: Iterator[TI] = (
-            session.query(TI)
+        max_tis: Iterator[TI] = session.scalars(
+            select(TI)
             .join(TI.dag_run)
-            .filter(
-                TI.dag_id == dag.dag_id,
-                TI.task_id == qry.c.task_id,
-                DR.execution_date == qry.c.max_ti,
-            )
+            .where(TI.dag_id == dag.dag_id, TI.task_id == qry.c.task_id, 
DR.execution_date == qry.c.max_ti)
         )
 
         ts = timezone.utcnow()
@@ -490,23 +486,18 @@ class DagFileProcessor(LoggingMixin):
             if sla_misses:
                 session.add_all(sla_misses)
         session.commit()
-
-        slas: list[SlaMiss] = (
-            session.query(SlaMiss)
-            .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == 
dag.dag_id)  # noqa
-            .all()
-        )
+        slas: list[SlaMiss] = session.scalars(
+            select(SlaMiss).where(~SlaMiss.notification_sent, SlaMiss.dag_id 
== dag.dag_id)
+        ).all()
         if slas:
             sla_dates: list[datetime] = [sla.execution_date for sla in slas]
-            fetched_tis: list[TI] = (
-                session.query(TI)
-                .filter(
+            fetched_tis: list[TI] = session.scalars(
+                select(TI).where(
                     TI.dag_id == dag.dag_id,
                     TI.execution_date.in_(sla_dates),
                     TI.state != TaskInstanceState.SUCCESS,
                 )
-                .all()
-            )
+            ).all()
             blocking_tis: list[TI] = []
             for ti in fetched_tis:
                 if ti.task_id in dag.task_ids:
diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py
index 2a1d55bba9..c7e062803a 100644
--- a/airflow/datasets/manager.py
+++ b/airflow/datasets/manager.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
-from sqlalchemy import exc
+from sqlalchemy import exc, select
 from sqlalchemy.orm.session import Session
 
 from airflow.configuration import conf
@@ -52,7 +52,7 @@ class DatasetManager(LoggingMixin):
         For local datasets, look them up, record the dataset event, queue 
dagruns, and broadcast
         the dataset event
         """
-        dataset_model = session.query(DatasetModel).filter(DatasetModel.uri == 
dataset.uri).one_or_none()
+        dataset_model = 
session.scalar(select(DatasetModel).where(DatasetModel.uri == dataset.uri))
         if not dataset_model:
             self.log.warning("DatasetModel %s not found", dataset)
             return
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 75fee04145..80b29f8e4c 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2029,8 +2029,7 @@ class DAG(LoggingMixin):
             raise ValueError("TaskGroup {group_id} could not be found")
         tasks_to_set_state = [task for task in task_group.iter_tasks() if 
isinstance(task, BaseOperator)]
         task_ids = [task.task_id for task in task_group.iter_tasks()]
-
-        dag_runs_query = session.query(DagRun.id).where(DagRun.dag_id == 
self.dag_id)
+        dag_runs_query = select(DagRun.id).where(DagRun.dag_id == self.dag_id)
         if start_date is None and end_date is None:
             dag_runs_query = dag_runs_query.where(DagRun.execution_date == 
start_date)
         else:
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index c8feb58d91..0616d042b9 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -20,6 +20,8 @@ from __future__ import annotations
 import warnings
 from typing import TYPE_CHECKING, Iterable, Sequence
 
+from sqlalchemy import select, update
+
 from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
@@ -67,24 +69,29 @@ class SkipMixin(LoggingMixin):
         """Set state of task instances to skipped from the same dag run."""
         if tasks:
             now = timezone.utcnow()
-            TI = TaskInstance
-            query = session.query(TI).filter(
-                TI.dag_id == dag_run.dag_id,
-                TI.run_id == dag_run.run_id,
-            )
+
             if isinstance(tasks[0], tuple):
-                query = query.filter(tuple_in_condition((TI.task_id, 
TI.map_index), tasks))
+                session.execute(
+                    update(TaskInstance)
+                    .where(
+                        TaskInstance.dag_id == dag_run.dag_id,
+                        TaskInstance.run_id == dag_run.run_id,
+                        tuple_in_condition((TaskInstance.task_id, 
TaskInstance.map_index), tasks),
+                    )
+                    .values(state=TaskInstanceState.SKIPPED, start_date=now, 
end_date=now)
+                    .execution_options(synchronize_session=False)
+                )
             else:
-                query = query.filter(TI.task_id.in_(tasks))
-
-            query.update(
-                {
-                    TaskInstance.state: TaskInstanceState.SKIPPED,
-                    TaskInstance.start_date: now,
-                    TaskInstance.end_date: now,
-                },
-                synchronize_session=False,
-            )
+                session.execute(
+                    update(TaskInstance)
+                    .where(
+                        TaskInstance.dag_id == dag_run.dag_id,
+                        TaskInstance.run_id == dag_run.run_id,
+                        TaskInstance.task_id.in_(tasks),
+                    )
+                    .values(state=TaskInstanceState.SKIPPED, start_date=now, 
end_date=now)
+                    .execution_options(synchronize_session=False)
+                )
 
     @provide_session
     def skip(
@@ -121,14 +128,12 @@ class SkipMixin(LoggingMixin):
                 stacklevel=2,
             )
 
-            dag_run = (
-                session.query(DagRun)
-                .filter(
-                    DagRun.dag_id == task_list[0].dag_id,
-                    DagRun.execution_date == execution_date,
+            dag_run = session.scalars(
+                select(DagRun).where(
+                    DagRun.dag_id == task_list[0].dag_id, 
DagRun.execution_date == execution_date
                 )
-                .one()
-            )
+            ).one()
+
         elif execution_date and dag_run and execution_date != 
dag_run.execution_date:
             raise ValueError(
                 "execution_date has a different value to  
dag_run.execution_date -- please only pass dag_run"
diff --git a/airflow/models/variable.py b/airflow/models/variable.py
index 5ca7774cbf..07cc9ffb6c 100644
--- a/airflow/models/variable.py
+++ b/airflow/models/variable.py
@@ -21,7 +21,7 @@ import json
 import logging
 from typing import Any
 
-from sqlalchemy import Boolean, Column, Integer, String, Text, delete
+from sqlalchemy import Boolean, Column, Integer, String, Text, delete, select
 from sqlalchemy.dialects.mysql import MEDIUMTEXT
 from sqlalchemy.orm import Session, declared_attr, reconstructor, synonym
 
@@ -200,8 +200,7 @@ class Variable(Base, LoggingMixin):
 
         if Variable.get_variable_from_secrets(key=key) is None:
             raise KeyError(f"Variable {key} does not exist")
-
-        obj = session.query(Variable).filter(Variable.key == key).first()
+        obj = session.scalar(select(Variable).where(Variable.key == key))
         if obj is None:
             raise AttributeError(f"Variable {key} does not exist in the 
Database and cannot be updated.")
 
diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py
index 680497217d..345c36e72b 100644
--- a/airflow/operators/subdag.py
+++ b/airflow/operators/subdag.py
@@ -26,6 +26,7 @@ import warnings
 from datetime import datetime
 from enum import Enum
 
+from sqlalchemy import select
 from sqlalchemy.orm.session import Session
 
 from airflow.api.common.experimental.get_task_instance import get_task_instance
@@ -112,7 +113,7 @@ class SubDagOperator(BaseSensorOperator):
             conflicts = [t for t in self.subdag.tasks if t.pool == self.pool]
             if conflicts:
                 # only query for pool conflicts if one may exist
-                pool = session.query(Pool).filter(Pool.slots == 
1).filter(Pool.pool == self.pool).first()
+                pool = session.scalar(select(Pool).where(Pool.slots == 1, 
Pool.pool == self.pool))
                 if pool and any(t.pool == self.pool for t in 
self.subdag.tasks):
                     raise AirflowException(
                         f"SubDagOperator {self.task_id} and subdag task{'s' if 
len(conflicts) > 1 else ''} "
@@ -139,11 +140,11 @@ class SubDagOperator(BaseSensorOperator):
         with create_session() as session:
             dag_run.state = DagRunState.RUNNING
             session.merge(dag_run)
-            failed_task_instances = (
-                session.query(TaskInstance)
-                .filter(TaskInstance.dag_id == self.subdag.dag_id)
-                .filter(TaskInstance.execution_date == execution_date)
-                .filter(TaskInstance.state.in_((TaskInstanceState.FAILED, 
TaskInstanceState.UPSTREAM_FAILED)))
+            failed_task_instances = session.scalars(
+                select(TaskInstance)
+                .where(TaskInstance.dag_id == self.subdag.dag_id)
+                .where(TaskInstance.execution_date == execution_date)
+                .where(TaskInstance.state.in_((TaskInstanceState.FAILED, 
TaskInstanceState.UPSTREAM_FAILED)))
             )
 
             for task_instance in failed_task_instances:
diff --git a/airflow/providers/celery/executors/celery_executor_utils.py b/airflow/providers/celery/executors/celery_executor_utils.py
index 2e739f239a..5cd8ea7eb1 100644
--- a/airflow/providers/celery/executors/celery_executor_utils.py
+++ b/airflow/providers/celery/executors/celery_executor_utils.py
@@ -37,6 +37,7 @@ from celery.backends.database import DatabaseBackend, Task as 
TaskDb, retry, ses
 from celery.result import AsyncResult
 from celery.signals import import_modules as celery_import_modules
 from setproctitle import setproctitle
+from sqlalchemy import select
 
 import airflow.settings as settings
 from airflow.configuration import conf
@@ -268,7 +269,7 @@ class BulkStateFetcher(LoggingMixin):
         session = app.backend.ResultSession()
         task_cls = getattr(app.backend, "task_cls", TaskDb)
         with session_cleanup(session):
-            return 
session.query(task_cls).filter(task_cls.task_id.in_(task_ids)).all()
+            return 
session.scalars(select(task_cls).where(task_cls.task_id.in_(task_ids))).all()
 
     def _get_many_from_db_backend(self, async_tasks) -> Mapping[str, 
EventBufferValueType]:
         task_ids = self._tasks_list_to_task_ids(async_tasks)
diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
index cf74c34d97..8e95f7a6c9 100644
--- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
+++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
@@ -34,6 +34,7 @@ from datetime import datetime
 from queue import Empty, Queue
 from typing import TYPE_CHECKING, Any, Sequence
 
+from sqlalchemy import select, update
 from sqlalchemy.orm import Session
 
 from airflow import AirflowException
@@ -215,12 +216,12 @@ class KubernetesExecutor(BaseExecutor):
         from airflow.models.taskinstance import TaskInstance
 
         self.log.debug("Clearing tasks that have not been launched")
-        query = session.query(TaskInstance).filter(
+        query = select(TaskInstance).where(
             TaskInstance.state == TaskInstanceState.QUEUED, 
TaskInstance.queued_by_job_id == self.job_id
         )
         if self.kubernetes_queue:
-            query = query.filter(TaskInstance.queue == self.kubernetes_queue)
-        queued_tis: list[TaskInstance] = query.all()
+            query = query.where(TaskInstance.queue == self.kubernetes_queue)
+        queued_tis: list[TaskInstance] = session.scalars(query).all()
         self.log.info("Found %s queued task instances", len(queued_tis))
 
         # Go through the "last seen" dictionary and clean out old entries
@@ -262,12 +263,16 @@ class KubernetesExecutor(BaseExecutor):
             if pod_list:
                 continue
             self.log.info("TaskInstance: %s found in queued state but was not 
launched, rescheduling", ti)
-            session.query(TaskInstance).filter(
-                TaskInstance.dag_id == ti.dag_id,
-                TaskInstance.task_id == ti.task_id,
-                TaskInstance.run_id == ti.run_id,
-                TaskInstance.map_index == ti.map_index,
-            ).update({TaskInstance.state: TaskInstanceState.SCHEDULED})
+            session.execute(
+                update(TaskInstance)
+                .where(
+                    TaskInstance.dag_id == ti.dag_id,
+                    TaskInstance.task_id == ti.task_id,
+                    TaskInstance.run_id == ti.run_id,
+                    TaskInstance.map_index == ti.map_index,
+                )
+                .values(state=TaskInstanceState.SCHEDULED)
+            )
 
     def start(self) -> None:
         """Starts the executor."""
@@ -457,7 +462,7 @@ class KubernetesExecutor(BaseExecutor):
         if state is None:
             from airflow.models.taskinstance import TaskInstance
 
-            state = 
session.query(TaskInstance.state).filter(TaskInstance.filter_for_tis([key])).scalar()
+            state = 
session.scalar(select(TaskInstance.state).where(TaskInstance.filter_for_tis([key])))
             state = TaskInstanceState(state)
 
         self.event_buffer[key] = state, None
diff --git a/airflow/utils/scheduler_health.py b/airflow/utils/scheduler_health.py
index 6108aee018..0a068d7f72 100644
--- a/airflow/utils/scheduler_health.py
+++ b/airflow/utils/scheduler_health.py
@@ -19,6 +19,8 @@ from __future__ import annotations
 import logging
 from http.server import BaseHTTPRequestHandler, HTTPServer
 
+from sqlalchemy import select
+
 from airflow.configuration import conf
 from airflow.jobs.job import Job
 from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
@@ -35,12 +37,12 @@ class HealthServer(BaseHTTPRequestHandler):
         if self.path == "/health":
             try:
                 with create_session() as session:
-                    scheduler_job = (
-                        session.query(Job)
+                    scheduler_job = session.scalar(
+                        select(Job)
                         .filter_by(job_type=SchedulerJobRunner.job_type)
                         .filter_by(hostname=get_hostname())
                         .order_by(Job.latest_heartbeat.desc())
-                        .first()
+                        .limit(1)
                     )
                 if scheduler_job and scheduler_job.is_alive():
                     self.send_response(200)
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index f97570b94c..bb2277e4ed 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -436,7 +436,7 @@ def lock_rows(query: Query, session: Session) -> 
Generator[None, None, None]:
 
     :meta private:
     """
-    locked_rows = with_row_locks(query, session).all()
+    locked_rows = with_row_locks(query, session)
     yield
     del locked_rows
 
diff --git a/airflow/www/security.py b/airflow/www/security.py
index a988dccd2e..93f8fd11fc 100644
--- a/airflow/www/security.py
+++ b/airflow/www/security.py
@@ -20,7 +20,7 @@ import warnings
 from typing import TYPE_CHECKING, Any, Collection, Container, Iterable, 
Sequence
 
 from flask import g
-from sqlalchemy import or_
+from sqlalchemy import or_, select
 from sqlalchemy.orm import Session, joinedload
 
 from airflow.auth.managers.fab.models import Permission, Resource, Role, User
@@ -243,11 +243,9 @@ class AirflowSecurityManager(SecurityManagerOverride, 
SecurityManager, LoggingMi
 
     def _get_root_dag_id(self, dag_id: str) -> str:
         if "." in dag_id:
-            dm = (
-                self.appbuilder.get_session.query(DagModel.dag_id, 
DagModel.root_dag_id)
-                .filter(DagModel.dag_id == dag_id)
-                .first()
-            )
+            dm = self.appbuilder.get_session.execute(
+                select(DagModel.dag_id, 
DagModel.root_dag_id).where(DagModel.dag_id == dag_id)
+            ).one()
             return dm.root_dag_id or dm.dag_id
         return dag_id
 
@@ -331,7 +329,7 @@ class AirflowSecurityManager(SecurityManagerOverride, 
SecurityManager, LoggingMi
             stacklevel=3,
         )
         dag_ids = self.get_accessible_dag_ids(user, user_actions, session)
-        return session.query(DagModel).filter(DagModel.dag_id.in_(dag_ids))
+        return 
session.scalars(select(DagModel).where(DagModel.dag_id.in_(dag_ids)))
 
     def get_readable_dag_ids(self, user) -> set[str]:
         """Gets the DAG IDs readable by authenticated user."""
@@ -358,16 +356,15 @@ class AirflowSecurityManager(SecurityManagerOverride, 
SecurityManager, LoggingMi
             if (permissions.ACTION_CAN_EDIT in user_actions and 
self.can_edit_all_dags(user)) or (
                 permissions.ACTION_CAN_READ in user_actions and 
self.can_read_all_dags(user)
             ):
-                return {dag.dag_id for dag in session.query(DagModel.dag_id)}
-            user_query = (
-                session.query(User)
+                return {dag.dag_id for dag in 
session.execute(select(DagModel.dag_id))}
+            user_query = session.scalar(
+                select(User)
                 .options(
                     joinedload(User.roles)
                     .subqueryload(Role.permissions)
                     .options(joinedload(Permission.action), 
joinedload(Permission.resource))
                 )
-                .filter(User.id == user.id)
-                .first()
+                .where(User.id == user.id)
             )
             roles = user_query.roles
 
@@ -380,13 +377,16 @@ class AirflowSecurityManager(SecurityManagerOverride, 
SecurityManager, LoggingMi
 
                 resource = permission.resource.name
                 if resource == permissions.RESOURCE_DAG:
-                    return {dag.dag_id for dag in 
session.query(DagModel.dag_id)}
+                    return {dag.dag_id for dag in 
session.execute(select(DagModel.dag_id))}
 
                 if resource.startswith(permissions.RESOURCE_DAG_PREFIX):
                     
resources.add(resource[len(permissions.RESOURCE_DAG_PREFIX) :])
                 else:
                     resources.add(resource)
-        return {dag.dag_id for dag in 
session.query(DagModel.dag_id).filter(DagModel.dag_id.in_(resources))}
+        return {
+            dag.dag_id
+            for dag in 
session.execute(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources)))
+        }
 
     def can_access_some_dags(self, action: str, dag_id: str | None = None) -> 
bool:
         """Checks if user has read or write access to some dags."""
@@ -524,10 +524,8 @@ class AirflowSecurityManager(SecurityManagerOverride, 
SecurityManager, LoggingMi
         resource = self.get_resource(resource_name)
         perm = None
         if action and resource:
-            perm = (
-                self.appbuilder.get_session.query(self.permission_model)
-                .filter_by(action=action, resource=resource)
-                .first()
+            perm = self.appbuilder.get_session.scalar(
+                select(self.permission_model).filter_by(action=action, 
resource=resource).limit(1)
             )
         if not perm and action_name and resource_name:
             self.create_permission(action_name, resource_name)
@@ -548,11 +546,11 @@ class AirflowSecurityManager(SecurityManagerOverride, 
SecurityManager, LoggingMi
     def get_all_permissions(self) -> set[tuple[str, str]]:
         """Returns all permissions as a set of tuples with the action and 
resource names."""
         return set(
-            self.appbuilder.get_session.query(self.permission_model)
-            .join(self.permission_model.action)
-            .join(self.permission_model.resource)
-            .with_entities(self.action_model.name, self.resource_model.name)
-            .all()
+            self.appbuilder.get_session.execute(
+                select(self.action_model.name, self.resource_model.name)
+                .join(self.permission_model.action)
+                .join(self.permission_model.resource)
+            )
         )
 
     def _get_all_non_dag_permissions(self) -> dict[tuple[str, str], 
Permission]:
@@ -565,12 +563,12 @@ class AirflowSecurityManager(SecurityManagerOverride, 
SecurityManager, LoggingMi
         return {
             (action_name, resource_name): viewmodel
             for action_name, resource_name, viewmodel in (
-                self.appbuilder.get_session.query(self.permission_model)
-                .join(self.permission_model.action)
-                .join(self.permission_model.resource)
-                
.filter(~self.resource_model.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%"))
-                .with_entities(self.action_model.name, 
self.resource_model.name, self.permission_model)
-                .all()
+                self.appbuilder.get_session.execute(
+                    select(self.action_model.name, self.resource_model.name, 
self.permission_model)
+                    .join(self.permission_model.action)
+                    .join(self.permission_model.resource)
+                    
.where(~self.resource_model.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%"))
+                )
             )
         }
 
@@ -578,9 +576,9 @@ class AirflowSecurityManager(SecurityManagerOverride, 
SecurityManager, LoggingMi
         """Returns a dict with a key of role name and value of role with early 
loaded permissions."""
         return {
             r.name: r
-            for r in 
self.appbuilder.get_session.query(self.role_model).options(
-                joinedload(self.role_model.permissions)
-            )
+            for r in self.appbuilder.get_session.scalars(
+                
select(self.role_model).options(joinedload(self.role_model.permissions))
+            ).unique()
         }
 
     def create_dag_specific_permissions(self) -> None:
@@ -621,12 +619,12 @@ class AirflowSecurityManager(SecurityManagerOverride, 
SecurityManager, LoggingMi
         :return: None.
         """
         session = self.appbuilder.get_session
-        dag_resources = session.query(Resource).filter(
-            Resource.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%")
+        dag_resources = session.scalars(
+            
select(Resource).where(Resource.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%"))
         )
         resource_ids = [resource.id for resource in dag_resources]
 
-        perms = 
session.query(Permission).filter(~Permission.resource_id.in_(resource_ids))
+        perms = 
session.scalars(select(Permission).where(~Permission.resource_id.in_(resource_ids)))
         perms = [p for p in perms if p.action and p.resource]
 
         admin = self.find_role("Admin")
diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py
index e462d0093e..10d54744b9 100644
--- a/tests/datasets/test_manager.py
+++ b/tests/datasets/test_manager.py
@@ -54,7 +54,7 @@ class TestDatasetManager:
 
         mock_session = mock.Mock()
         # Gotta mock up the query results
-        
mock_session.query.return_value.filter.return_value.one_or_none.return_value = 
None
+        mock_session.scalar.return_value = None
 
         dsem.register_dataset_change(task_instance=mock_task_instance, 
dataset=dataset, session=mock_session)
 
diff --git a/tests/integration/executors/test_celery_executor.py b/tests/integration/executors/test_celery_executor.py
index 9e8d365c5c..26c2af03db 100644
--- a/tests/integration/executors/test_celery_executor.py
+++ b/tests/integration/executors/test_celery_executor.py
@@ -311,7 +311,7 @@ class TestBulkStateFetcher:
             ):
                 caplog.clear()
                 mock_session = mock_backend.ResultSession.return_value
-                
mock_session.query.return_value.filter.return_value.all.return_value = [
+                mock_session.scalars.return_value.all.return_value = [
                     mock.MagicMock(**{"to_dict.return_value": {"status": 
"SUCCESS", "task_id": "123"}})
                 ]
 
@@ -340,7 +340,7 @@ class TestBulkStateFetcher:
             ):
                 caplog.clear()
                 mock_session = mock_backend.ResultSession.return_value
-                mock_retry_db_result = 
mock_session.query.return_value.filter.return_value.all
+                mock_retry_db_result = mock_session.scalars.return_value.all
                 mock_retry_db_result.return_value = [
                     mock.MagicMock(**{"to_dict.return_value": {"status": 
"SUCCESS", "task_id": "123"}})
                 ]

Reply via email to