This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4c4981d1ad Refactor Sqlalchemy queries to 2.0 style (Part 7) (#32883)
4c4981d1ad is described below
commit 4c4981d1adf2bd8b28ffa7e6ed57162abb8feb8f
Author: Phani Kumar <[email protected]>
AuthorDate: Mon Aug 21 19:04:38 2023 +0530
Refactor Sqlalchemy queries to 2.0 style (Part 7) (#32883)
---
airflow/api/common/experimental/pool.py | 2 +-
airflow/auth/managers/fab/models/__init__.py | 14 +++--
airflow/dag_processing/processor.py | 45 ++++++--------
airflow/datasets/manager.py | 4 +-
airflow/models/dag.py | 3 +-
airflow/models/skipmixin.py | 51 ++++++++--------
airflow/models/variable.py | 5 +-
airflow/operators/subdag.py | 13 +++--
.../celery/executors/celery_executor_utils.py | 3 +-
.../kubernetes/executors/kubernetes_executor.py | 25 ++++----
airflow/utils/scheduler_health.py | 8 ++-
airflow/utils/sqlalchemy.py | 2 +-
airflow/www/security.py | 68 +++++++++++-----------
tests/datasets/test_manager.py | 2 +-
.../integration/executors/test_celery_executor.py | 4 +-
15 files changed, 126 insertions(+), 123 deletions(-)
diff --git a/airflow/api/common/experimental/pool.py
b/airflow/api/common/experimental/pool.py
index 34e35cd435..1134bda989 100644
--- a/airflow/api/common/experimental/pool.py
+++ b/airflow/api/common/experimental/pool.py
@@ -45,7 +45,7 @@ def get_pool(name, session: Session = NEW_SESSION):
@provide_session
def get_pools(session: Session = NEW_SESSION):
"""Get all pools."""
- return session.query(Pool).all()
+ return session.scalars(select(Pool)).all()
@deprecated(reason="Use Pool.create_pool() instead", version="2.2.4")
diff --git a/airflow/auth/managers/fab/models/__init__.py
b/airflow/auth/managers/fab/models/__init__.py
index 0bc26adb7e..28de7840ec 100644
--- a/airflow/auth/managers/fab/models/__init__.py
+++ b/airflow/auth/managers/fab/models/__init__.py
@@ -37,6 +37,7 @@ from sqlalchemy import (
UniqueConstraint,
event,
func,
+ select,
)
from sqlalchemy.orm import backref, declared_attr, relationship
@@ -210,12 +211,13 @@ class User(Model, BaseUser):
if current_app:
sm = current_app.appbuilder.sm
self._perms: set[tuple[str, str]] = set(
- sm.get_session.query(sm.action_model.name,
sm.resource_model.name)
- .join(sm.permission_model.action)
- .join(sm.permission_model.resource)
- .join(sm.permission_model.role)
- .filter(sm.role_model.user.contains(self))
- .all()
+ sm.get_session.execute(
+ select(sm.action_model.name, sm.resource_model.name)
+ .join(sm.permission_model.action)
+ .join(sm.permission_model.resource)
+ .join(sm.permission_model.role)
+ .where(sm.role_model.user.contains(self))
+ )
)
else:
self._perms = {
diff --git a/airflow/dag_processing/processor.py
b/airflow/dag_processing/processor.py
index 162fc5889c..64858be94e 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -30,7 +30,7 @@ from multiprocessing.connection import Connection as
MultiprocessingConnection
from typing import TYPE_CHECKING, Iterable, Iterator
from setproctitle import setproctitle
-from sqlalchemy import delete, exc, func, or_
+from sqlalchemy import delete, exc, func, or_, select
from sqlalchemy.orm.session import Session
from airflow import settings
@@ -428,31 +428,27 @@ class DagFileProcessor(LoggingMixin):
if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks):
cls.logger().info("Skipping SLA check for %s because no tasks in
DAG have SLAs", dag)
return
-
qry = (
- session.query(TI.task_id,
func.max(DR.execution_date).label("max_ti"))
+ select(TI.task_id, func.max(DR.execution_date).label("max_ti"))
.join(TI.dag_run)
- .filter(TI.dag_id == dag.dag_id)
- .filter(or_(TI.state == TaskInstanceState.SUCCESS, TI.state ==
TaskInstanceState.SKIPPED))
- .filter(TI.task_id.in_(dag.task_ids))
+ .where(TI.dag_id == dag.dag_id)
+ .where(or_(TI.state == TaskInstanceState.SUCCESS, TI.state ==
TaskInstanceState.SKIPPED))
+ .where(TI.task_id.in_(dag.task_ids))
.group_by(TI.task_id)
.subquery("sq")
)
# get recorded SlaMiss
recorded_slas_query = set(
- session.query(SlaMiss.dag_id, SlaMiss.task_id,
SlaMiss.execution_date).filter(
- SlaMiss.dag_id == dag.dag_id, SlaMiss.task_id.in_(dag.task_ids)
+ session.execute(
+ select(SlaMiss.dag_id, SlaMiss.task_id,
SlaMiss.execution_date).where(
+ SlaMiss.dag_id == dag.dag_id,
SlaMiss.task_id.in_(dag.task_ids)
+ )
)
)
-
- max_tis: Iterator[TI] = (
- session.query(TI)
+ max_tis: Iterator[TI] = session.scalars(
+ select(TI)
.join(TI.dag_run)
- .filter(
- TI.dag_id == dag.dag_id,
- TI.task_id == qry.c.task_id,
- DR.execution_date == qry.c.max_ti,
- )
+ .where(TI.dag_id == dag.dag_id, TI.task_id == qry.c.task_id,
DR.execution_date == qry.c.max_ti)
)
ts = timezone.utcnow()
@@ -490,23 +486,18 @@ class DagFileProcessor(LoggingMixin):
if sla_misses:
session.add_all(sla_misses)
session.commit()
-
- slas: list[SlaMiss] = (
- session.query(SlaMiss)
- .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id ==
dag.dag_id) # noqa
- .all()
- )
+ slas: list[SlaMiss] = session.scalars(
+ select(SlaMiss).where(~SlaMiss.notification_sent, SlaMiss.dag_id
== dag.dag_id)
+ ).all()
if slas:
sla_dates: list[datetime] = [sla.execution_date for sla in slas]
- fetched_tis: list[TI] = (
- session.query(TI)
- .filter(
+ fetched_tis: list[TI] = session.scalars(
+ select(TI).where(
TI.dag_id == dag.dag_id,
TI.execution_date.in_(sla_dates),
TI.state != TaskInstanceState.SUCCESS,
)
- .all()
- )
+ ).all()
blocking_tis: list[TI] = []
for ti in fetched_tis:
if ti.task_id in dag.task_ids:
diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py
index 2a1d55bba9..c7e062803a 100644
--- a/airflow/datasets/manager.py
+++ b/airflow/datasets/manager.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING
-from sqlalchemy import exc
+from sqlalchemy import exc, select
from sqlalchemy.orm.session import Session
from airflow.configuration import conf
@@ -52,7 +52,7 @@ class DatasetManager(LoggingMixin):
For local datasets, look them up, record the dataset event, queue
dagruns, and broadcast
the dataset event
"""
-        dataset_model = session.query(DatasetModel).filter(DatasetModel.uri == dataset.uri).one_or_none()
+        dataset_model = session.scalar(select(DatasetModel).where(DatasetModel.uri == dataset.uri))
if not dataset_model:
self.log.warning("DatasetModel %s not found", dataset)
return
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 75fee04145..80b29f8e4c 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2029,8 +2029,7 @@ class DAG(LoggingMixin):
raise ValueError("TaskGroup {group_id} could not be found")
tasks_to_set_state = [task for task in task_group.iter_tasks() if
isinstance(task, BaseOperator)]
task_ids = [task.task_id for task in task_group.iter_tasks()]
-
- dag_runs_query = session.query(DagRun.id).where(DagRun.dag_id ==
self.dag_id)
+ dag_runs_query = select(DagRun.id).where(DagRun.dag_id == self.dag_id)
if start_date is None and end_date is None:
dag_runs_query = dag_runs_query.where(DagRun.execution_date ==
start_date)
else:
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index c8feb58d91..0616d042b9 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -20,6 +20,8 @@ from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Iterable, Sequence
+from sqlalchemy import select, update
+
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
@@ -67,24 +69,29 @@ class SkipMixin(LoggingMixin):
"""Set state of task instances to skipped from the same dag run."""
if tasks:
now = timezone.utcnow()
- TI = TaskInstance
- query = session.query(TI).filter(
- TI.dag_id == dag_run.dag_id,
- TI.run_id == dag_run.run_id,
- )
+
if isinstance(tasks[0], tuple):
- query = query.filter(tuple_in_condition((TI.task_id,
TI.map_index), tasks))
+ session.execute(
+ update(TaskInstance)
+ .where(
+ TaskInstance.dag_id == dag_run.dag_id,
+ TaskInstance.run_id == dag_run.run_id,
+ tuple_in_condition((TaskInstance.task_id,
TaskInstance.map_index), tasks),
+ )
+ .values(state=TaskInstanceState.SKIPPED, start_date=now,
end_date=now)
+ .execution_options(synchronize_session=False)
+ )
else:
- query = query.filter(TI.task_id.in_(tasks))
-
- query.update(
- {
- TaskInstance.state: TaskInstanceState.SKIPPED,
- TaskInstance.start_date: now,
- TaskInstance.end_date: now,
- },
- synchronize_session=False,
- )
+ session.execute(
+ update(TaskInstance)
+ .where(
+ TaskInstance.dag_id == dag_run.dag_id,
+ TaskInstance.run_id == dag_run.run_id,
+ TaskInstance.task_id.in_(tasks),
+ )
+ .values(state=TaskInstanceState.SKIPPED, start_date=now,
end_date=now)
+ .execution_options(synchronize_session=False)
+ )
@provide_session
def skip(
@@ -121,14 +128,12 @@ class SkipMixin(LoggingMixin):
stacklevel=2,
)
- dag_run = (
- session.query(DagRun)
- .filter(
- DagRun.dag_id == task_list[0].dag_id,
- DagRun.execution_date == execution_date,
+ dag_run = session.scalars(
+ select(DagRun).where(
+ DagRun.dag_id == task_list[0].dag_id,
DagRun.execution_date == execution_date
)
- .one()
- )
+ ).one()
+
elif execution_date and dag_run and execution_date !=
dag_run.execution_date:
raise ValueError(
"execution_date has a different value to
dag_run.execution_date -- please only pass dag_run"
diff --git a/airflow/models/variable.py b/airflow/models/variable.py
index 5ca7774cbf..07cc9ffb6c 100644
--- a/airflow/models/variable.py
+++ b/airflow/models/variable.py
@@ -21,7 +21,7 @@ import json
import logging
from typing import Any
-from sqlalchemy import Boolean, Column, Integer, String, Text, delete
+from sqlalchemy import Boolean, Column, Integer, String, Text, delete, select
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.orm import Session, declared_attr, reconstructor, synonym
@@ -200,8 +200,7 @@ class Variable(Base, LoggingMixin):
if Variable.get_variable_from_secrets(key=key) is None:
raise KeyError(f"Variable {key} does not exist")
-
- obj = session.query(Variable).filter(Variable.key == key).first()
+ obj = session.scalar(select(Variable).where(Variable.key == key))
if obj is None:
raise AttributeError(f"Variable {key} does not exist in the
Database and cannot be updated.")
diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py
index 680497217d..345c36e72b 100644
--- a/airflow/operators/subdag.py
+++ b/airflow/operators/subdag.py
@@ -26,6 +26,7 @@ import warnings
from datetime import datetime
from enum import Enum
+from sqlalchemy import select
from sqlalchemy.orm.session import Session
from airflow.api.common.experimental.get_task_instance import get_task_instance
@@ -112,7 +113,7 @@ class SubDagOperator(BaseSensorOperator):
conflicts = [t for t in self.subdag.tasks if t.pool == self.pool]
if conflicts:
# only query for pool conflicts if one may exist
- pool = session.query(Pool).filter(Pool.slots ==
1).filter(Pool.pool == self.pool).first()
+ pool = session.scalar(select(Pool).where(Pool.slots == 1,
Pool.pool == self.pool))
if pool and any(t.pool == self.pool for t in
self.subdag.tasks):
raise AirflowException(
f"SubDagOperator {self.task_id} and subdag task{'s' if
len(conflicts) > 1 else ''} "
@@ -139,11 +140,11 @@ class SubDagOperator(BaseSensorOperator):
with create_session() as session:
dag_run.state = DagRunState.RUNNING
session.merge(dag_run)
- failed_task_instances = (
- session.query(TaskInstance)
- .filter(TaskInstance.dag_id == self.subdag.dag_id)
- .filter(TaskInstance.execution_date == execution_date)
- .filter(TaskInstance.state.in_((TaskInstanceState.FAILED,
TaskInstanceState.UPSTREAM_FAILED)))
+ failed_task_instances = session.scalars(
+ select(TaskInstance)
+ .where(TaskInstance.dag_id == self.subdag.dag_id)
+ .where(TaskInstance.execution_date == execution_date)
+ .where(TaskInstance.state.in_((TaskInstanceState.FAILED,
TaskInstanceState.UPSTREAM_FAILED)))
)
for task_instance in failed_task_instances:
diff --git a/airflow/providers/celery/executors/celery_executor_utils.py
b/airflow/providers/celery/executors/celery_executor_utils.py
index 2e739f239a..5cd8ea7eb1 100644
--- a/airflow/providers/celery/executors/celery_executor_utils.py
+++ b/airflow/providers/celery/executors/celery_executor_utils.py
@@ -37,6 +37,7 @@ from celery.backends.database import DatabaseBackend, Task as
TaskDb, retry, ses
from celery.result import AsyncResult
from celery.signals import import_modules as celery_import_modules
from setproctitle import setproctitle
+from sqlalchemy import select
import airflow.settings as settings
from airflow.configuration import conf
@@ -268,7 +269,7 @@ class BulkStateFetcher(LoggingMixin):
session = app.backend.ResultSession()
task_cls = getattr(app.backend, "task_cls", TaskDb)
with session_cleanup(session):
-            return session.query(task_cls).filter(task_cls.task_id.in_(task_ids)).all()
+            return session.scalars(select(task_cls).where(task_cls.task_id.in_(task_ids))).all()
def _get_many_from_db_backend(self, async_tasks) -> Mapping[str,
EventBufferValueType]:
task_ids = self._tasks_list_to_task_ids(async_tasks)
diff --git a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
index cf74c34d97..8e95f7a6c9 100644
--- a/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
+++ b/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
@@ -34,6 +34,7 @@ from datetime import datetime
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Sequence
+from sqlalchemy import select, update
from sqlalchemy.orm import Session
from airflow import AirflowException
@@ -215,12 +216,12 @@ class KubernetesExecutor(BaseExecutor):
from airflow.models.taskinstance import TaskInstance
self.log.debug("Clearing tasks that have not been launched")
- query = session.query(TaskInstance).filter(
+ query = select(TaskInstance).where(
TaskInstance.state == TaskInstanceState.QUEUED,
TaskInstance.queued_by_job_id == self.job_id
)
if self.kubernetes_queue:
- query = query.filter(TaskInstance.queue == self.kubernetes_queue)
- queued_tis: list[TaskInstance] = query.all()
+ query = query.where(TaskInstance.queue == self.kubernetes_queue)
+ queued_tis: list[TaskInstance] = session.scalars(query).all()
self.log.info("Found %s queued task instances", len(queued_tis))
# Go through the "last seen" dictionary and clean out old entries
@@ -262,12 +263,16 @@ class KubernetesExecutor(BaseExecutor):
if pod_list:
continue
self.log.info("TaskInstance: %s found in queued state but was not
launched, rescheduling", ti)
- session.query(TaskInstance).filter(
- TaskInstance.dag_id == ti.dag_id,
- TaskInstance.task_id == ti.task_id,
- TaskInstance.run_id == ti.run_id,
- TaskInstance.map_index == ti.map_index,
- ).update({TaskInstance.state: TaskInstanceState.SCHEDULED})
+ session.execute(
+ update(TaskInstance)
+ .where(
+ TaskInstance.dag_id == ti.dag_id,
+ TaskInstance.task_id == ti.task_id,
+ TaskInstance.run_id == ti.run_id,
+ TaskInstance.map_index == ti.map_index,
+ )
+ .values(state=TaskInstanceState.SCHEDULED)
+ )
def start(self) -> None:
"""Starts the executor."""
@@ -457,7 +462,7 @@ class KubernetesExecutor(BaseExecutor):
if state is None:
from airflow.models.taskinstance import TaskInstance
-                state = session.query(TaskInstance.state).filter(TaskInstance.filter_for_tis([key])).scalar()
+                state = session.scalar(select(TaskInstance.state).where(TaskInstance.filter_for_tis([key])))
state = TaskInstanceState(state)
self.event_buffer[key] = state, None
diff --git a/airflow/utils/scheduler_health.py
b/airflow/utils/scheduler_health.py
index 6108aee018..0a068d7f72 100644
--- a/airflow/utils/scheduler_health.py
+++ b/airflow/utils/scheduler_health.py
@@ -19,6 +19,8 @@ from __future__ import annotations
import logging
from http.server import BaseHTTPRequestHandler, HTTPServer
+from sqlalchemy import select
+
from airflow.configuration import conf
from airflow.jobs.job import Job
from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
@@ -35,12 +37,12 @@ class HealthServer(BaseHTTPRequestHandler):
if self.path == "/health":
try:
with create_session() as session:
- scheduler_job = (
- session.query(Job)
+ scheduler_job = session.scalar(
+ select(Job)
.filter_by(job_type=SchedulerJobRunner.job_type)
.filter_by(hostname=get_hostname())
.order_by(Job.latest_heartbeat.desc())
- .first()
+ .limit(1)
)
if scheduler_job and scheduler_job.is_alive():
self.send_response(200)
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index f97570b94c..bb2277e4ed 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -436,7 +436,7 @@ def lock_rows(query: Query, session: Session) ->
Generator[None, None, None]:
:meta private:
"""
- locked_rows = with_row_locks(query, session).all()
+ locked_rows = with_row_locks(query, session)
yield
del locked_rows
diff --git a/airflow/www/security.py b/airflow/www/security.py
index a988dccd2e..93f8fd11fc 100644
--- a/airflow/www/security.py
+++ b/airflow/www/security.py
@@ -20,7 +20,7 @@ import warnings
from typing import TYPE_CHECKING, Any, Collection, Container, Iterable,
Sequence
from flask import g
-from sqlalchemy import or_
+from sqlalchemy import or_, select
from sqlalchemy.orm import Session, joinedload
from airflow.auth.managers.fab.models import Permission, Resource, Role, User
@@ -243,11 +243,9 @@ class AirflowSecurityManager(SecurityManagerOverride,
SecurityManager, LoggingMi
def _get_root_dag_id(self, dag_id: str) -> str:
if "." in dag_id:
- dm = (
- self.appbuilder.get_session.query(DagModel.dag_id,
DagModel.root_dag_id)
- .filter(DagModel.dag_id == dag_id)
- .first()
- )
+ dm = self.appbuilder.get_session.execute(
+ select(DagModel.dag_id,
DagModel.root_dag_id).where(DagModel.dag_id == dag_id)
+ ).one()
return dm.root_dag_id or dm.dag_id
return dag_id
@@ -331,7 +329,7 @@ class AirflowSecurityManager(SecurityManagerOverride,
SecurityManager, LoggingMi
stacklevel=3,
)
dag_ids = self.get_accessible_dag_ids(user, user_actions, session)
- return session.query(DagModel).filter(DagModel.dag_id.in_(dag_ids))
+        return session.scalars(select(DagModel).where(DagModel.dag_id.in_(dag_ids)))
def get_readable_dag_ids(self, user) -> set[str]:
"""Gets the DAG IDs readable by authenticated user."""
@@ -358,16 +356,15 @@ class AirflowSecurityManager(SecurityManagerOverride,
SecurityManager, LoggingMi
if (permissions.ACTION_CAN_EDIT in user_actions and
self.can_edit_all_dags(user)) or (
permissions.ACTION_CAN_READ in user_actions and
self.can_read_all_dags(user)
):
- return {dag.dag_id for dag in session.query(DagModel.dag_id)}
- user_query = (
- session.query(User)
+ return {dag.dag_id for dag in
session.execute(select(DagModel.dag_id))}
+ user_query = session.scalar(
+ select(User)
.options(
joinedload(User.roles)
.subqueryload(Role.permissions)
.options(joinedload(Permission.action),
joinedload(Permission.resource))
)
- .filter(User.id == user.id)
- .first()
+ .where(User.id == user.id)
)
roles = user_query.roles
@@ -380,13 +377,16 @@ class AirflowSecurityManager(SecurityManagerOverride,
SecurityManager, LoggingMi
resource = permission.resource.name
if resource == permissions.RESOURCE_DAG:
- return {dag.dag_id for dag in
session.query(DagModel.dag_id)}
+ return {dag.dag_id for dag in
session.execute(select(DagModel.dag_id))}
if resource.startswith(permissions.RESOURCE_DAG_PREFIX):
resources.add(resource[len(permissions.RESOURCE_DAG_PREFIX) :])
else:
resources.add(resource)
- return {dag.dag_id for dag in
session.query(DagModel.dag_id).filter(DagModel.dag_id.in_(resources))}
+ return {
+ dag.dag_id
+ for dag in
session.execute(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources)))
+ }
def can_access_some_dags(self, action: str, dag_id: str | None = None) ->
bool:
"""Checks if user has read or write access to some dags."""
@@ -524,10 +524,8 @@ class AirflowSecurityManager(SecurityManagerOverride,
SecurityManager, LoggingMi
resource = self.get_resource(resource_name)
perm = None
if action and resource:
- perm = (
- self.appbuilder.get_session.query(self.permission_model)
- .filter_by(action=action, resource=resource)
- .first()
+ perm = self.appbuilder.get_session.scalar(
+ select(self.permission_model).filter_by(action=action,
resource=resource).limit(1)
)
if not perm and action_name and resource_name:
self.create_permission(action_name, resource_name)
@@ -548,11 +546,11 @@ class AirflowSecurityManager(SecurityManagerOverride,
SecurityManager, LoggingMi
def get_all_permissions(self) -> set[tuple[str, str]]:
"""Returns all permissions as a set of tuples with the action and
resource names."""
return set(
- self.appbuilder.get_session.query(self.permission_model)
- .join(self.permission_model.action)
- .join(self.permission_model.resource)
- .with_entities(self.action_model.name, self.resource_model.name)
- .all()
+ self.appbuilder.get_session.execute(
+ select(self.action_model.name, self.resource_model.name)
+ .join(self.permission_model.action)
+ .join(self.permission_model.resource)
+ )
)
def _get_all_non_dag_permissions(self) -> dict[tuple[str, str],
Permission]:
@@ -565,12 +563,12 @@ class AirflowSecurityManager(SecurityManagerOverride,
SecurityManager, LoggingMi
return {
(action_name, resource_name): viewmodel
for action_name, resource_name, viewmodel in (
- self.appbuilder.get_session.query(self.permission_model)
- .join(self.permission_model.action)
- .join(self.permission_model.resource)
-
.filter(~self.resource_model.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%"))
- .with_entities(self.action_model.name,
self.resource_model.name, self.permission_model)
- .all()
+ self.appbuilder.get_session.execute(
+ select(self.action_model.name, self.resource_model.name,
self.permission_model)
+ .join(self.permission_model.action)
+ .join(self.permission_model.resource)
+
.where(~self.resource_model.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%"))
+ )
)
}
@@ -578,9 +576,9 @@ class AirflowSecurityManager(SecurityManagerOverride,
SecurityManager, LoggingMi
"""Returns a dict with a key of role name and value of role with early
loaded permissions."""
return {
r.name: r
- for r in
self.appbuilder.get_session.query(self.role_model).options(
- joinedload(self.role_model.permissions)
- )
+ for r in self.appbuilder.get_session.scalars(
+
select(self.role_model).options(joinedload(self.role_model.permissions))
+ ).unique()
}
def create_dag_specific_permissions(self) -> None:
@@ -621,12 +619,12 @@ class AirflowSecurityManager(SecurityManagerOverride,
SecurityManager, LoggingMi
:return: None.
"""
session = self.appbuilder.get_session
- dag_resources = session.query(Resource).filter(
- Resource.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%")
+ dag_resources = session.scalars(
+
select(Resource).where(Resource.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%"))
)
resource_ids = [resource.id for resource in dag_resources]
-        perms = session.query(Permission).filter(~Permission.resource_id.in_(resource_ids))
+        perms = session.scalars(select(Permission).where(~Permission.resource_id.in_(resource_ids)))
perms = [p for p in perms if p.action and p.resource]
admin = self.find_role("Admin")
diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py
index e462d0093e..10d54744b9 100644
--- a/tests/datasets/test_manager.py
+++ b/tests/datasets/test_manager.py
@@ -54,7 +54,7 @@ class TestDatasetManager:
mock_session = mock.Mock()
# Gotta mock up the query results
-
mock_session.query.return_value.filter.return_value.one_or_none.return_value =
None
+ mock_session.scalar.return_value = None
dsem.register_dataset_change(task_instance=mock_task_instance,
dataset=dataset, session=mock_session)
diff --git a/tests/integration/executors/test_celery_executor.py
b/tests/integration/executors/test_celery_executor.py
index 9e8d365c5c..26c2af03db 100644
--- a/tests/integration/executors/test_celery_executor.py
+++ b/tests/integration/executors/test_celery_executor.py
@@ -311,7 +311,7 @@ class TestBulkStateFetcher:
):
caplog.clear()
mock_session = mock_backend.ResultSession.return_value
-
mock_session.query.return_value.filter.return_value.all.return_value = [
+ mock_session.scalars.return_value.all.return_value = [
mock.MagicMock(**{"to_dict.return_value": {"status":
"SUCCESS", "task_id": "123"}})
]
@@ -340,7 +340,7 @@ class TestBulkStateFetcher:
):
caplog.clear()
mock_session = mock_backend.ResultSession.return_value
- mock_retry_db_result =
mock_session.query.return_value.filter.return_value.all
+ mock_retry_db_result = mock_session.scalars.return_value.all
mock_retry_db_result.return_value = [
mock.MagicMock(**{"to_dict.return_value": {"status":
"SUCCESS", "task_id": "123"}})
]