This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2f44ea5996 Refactor Sqlalchemy queries to 2.0 style (Part 5) (#32474)
2f44ea5996 is described below
commit 2f44ea5996b9664865d6d891537f770caed19a0d
Author: Phani Kumar <[email protected]>
AuthorDate: Wed Jul 12 11:00:24 2023 +0530
Refactor Sqlalchemy queries to 2.0 style (Part 5) (#32474)
---
airflow/cli/commands/dag_command.py | 22 ++++----
airflow/cli/commands/jobs_command.py | 11 ++--
airflow/cli/commands/rotate_fernet_key_command.py | 7 ++-
airflow/cli/commands/task_command.py | 21 +++-----
airflow/cli/commands/variable_command.py | 6 ++-
airflow/models/renderedtifields.py | 15 +++---
airflow/utils/sqlalchemy.py | 64 ++++++++++++++++++++---
tests/cli/commands/test_task_command.py | 8 +--
8 files changed, 103 insertions(+), 51 deletions(-)
diff --git a/airflow/cli/commands/dag_command.py
b/airflow/cli/commands/dag_command.py
index f57e26b38c..66decad78b 100644
--- a/airflow/cli/commands/dag_command.py
+++ b/airflow/cli/commands/dag_command.py
@@ -28,7 +28,7 @@ import sys
import warnings
from graphviz.dot import Dot
-from sqlalchemy import delete
+from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from airflow import settings
@@ -287,7 +287,7 @@ def dag_state(args, session: Session = NEW_SESSION) -> None:
if not dag:
raise SystemExit(f"DAG: {args.dag_id} does not exist in 'dag' table")
-    dr = session.query(DagRun).filter_by(dag_id=args.dag_id, execution_date=args.execution_date).one_or_none()
+    dr = session.scalar(select(DagRun).filter_by(dag_id=args.dag_id, execution_date=args.execution_date))
out = dr.state if dr else None
conf_out = ""
if out and dr.conf:
@@ -309,7 +309,9 @@ def dag_next_execution(args) -> None:
         print("[INFO] Please be reminded this DAG is PAUSED now.", file=sys.stderr)
with create_session() as session:
-        last_parsed_dag: DagModel = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).one()
+ last_parsed_dag: DagModel = session.scalars(
+ select(DagModel).where(DagModel.dag_id == dag.dag_id)
+ ).one()
def print_execution_interval(interval: DataInterval | None):
if interval is None:
@@ -428,8 +430,10 @@ def dag_list_jobs(args, dag: DAG | None = None, session: Session = NEW_SESSION)
queries.append(Job.state == args.state)
fields = ["dag_id", "state", "job_type", "start_date", "end_date"]
-    all_jobs = session.query(Job).filter(*queries).order_by(Job.start_date.desc()).limit(args.limit).all()
-    all_jobs = [{f: str(job.__getattribute__(f)) for f in fields} for job in all_jobs]
+    all_jobs_iter = session.scalars(
+        select(Job).where(*queries).order_by(Job.start_date.desc()).limit(args.limit)
+    )
+    all_jobs = [{f: str(job.__getattribute__(f)) for f in fields} for job in all_jobs_iter]
AirflowConsole().print_as(
data=all_jobs,
@@ -492,14 +496,12 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No
imgcat = args.imgcat_dagrun
filename = args.save_dagrun
if show_dagrun or imgcat or filename:
- tis = (
- session.query(TaskInstance)
- .filter(
+ tis = session.scalars(
+ select(TaskInstance).where(
TaskInstance.dag_id == args.dag_id,
TaskInstance.execution_date == execution_date,
)
- .all()
- )
+ ).all()
dot_graph = render_dag(dag, tis=tis)
print()
diff --git a/airflow/cli/commands/jobs_command.py
b/airflow/cli/commands/jobs_command.py
index 79e0aad79f..339896d877 100644
--- a/airflow/cli/commands/jobs_command.py
+++ b/airflow/cli/commands/jobs_command.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+from sqlalchemy import select
from sqlalchemy.orm import Session
from airflow.jobs.job import Job
@@ -32,17 +33,17 @@ def check(args, session: Session = NEW_SESSION) -> None:
if args.hostname and args.local:
         raise SystemExit("You can't use --hostname and --local at the same time")
-    query = session.query(Job).filter(Job.state == State.RUNNING).order_by(Job.latest_heartbeat.desc())
+    query = select(Job).where(Job.state == State.RUNNING).order_by(Job.latest_heartbeat.desc())
if args.job_type:
- query = query.filter(Job.job_type == args.job_type)
+ query = query.where(Job.job_type == args.job_type)
if args.hostname:
- query = query.filter(Job.hostname == args.hostname)
+ query = query.where(Job.hostname == args.hostname)
if args.local:
- query = query.filter(Job.hostname == get_hostname())
+ query = query.where(Job.hostname == get_hostname())
if args.limit > 0:
query = query.limit(args.limit)
- alive_jobs: list[Job] = [job for job in query.all() if job.is_alive()]
+    alive_jobs: list[Job] = [job for job in session.scalars(query) if job.is_alive()]
count_alive_jobs = len(alive_jobs)
if count_alive_jobs == 0:
diff --git a/airflow/cli/commands/rotate_fernet_key_command.py
b/airflow/cli/commands/rotate_fernet_key_command.py
index f9e1873597..e9973978e0 100644
--- a/airflow/cli/commands/rotate_fernet_key_command.py
+++ b/airflow/cli/commands/rotate_fernet_key_command.py
@@ -17,6 +17,8 @@
"""Rotate Fernet key command."""
from __future__ import annotations
+from sqlalchemy import select
+
from airflow.models import Connection, Variable
from airflow.utils import cli as cli_utils
from airflow.utils.session import create_session
@@ -26,7 +28,8 @@ from airflow.utils.session import create_session
def rotate_fernet_key(args):
"""Rotates all encrypted connection credentials and variables."""
with create_session() as session:
-        for conn in session.query(Connection).filter(Connection.is_encrypted | Connection.is_extra_encrypted):
+        conns_query = select(Connection).where(Connection.is_encrypted | Connection.is_extra_encrypted)
+        for conn in session.scalars(conns_query):
             conn.rotate_fernet_key()
-        for var in session.query(Variable).filter(Variable.is_encrypted):
+        for var in session.scalars(select(Variable).where(Variable.is_encrypted)):
var.rotate_fernet_key()
diff --git a/airflow/cli/commands/task_command.py
b/airflow/cli/commands/task_command.py
index a5e20e8751..796d3a8cf1 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -29,6 +29,7 @@ from typing import Generator, Union, cast
import pendulum
from pendulum.parsing.exceptions import ParserError
+from sqlalchemy import select
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.session import Session
@@ -111,11 +112,9 @@ def _get_dag_run(
with suppress(ParserError, TypeError):
execution_date = timezone.parse(exec_date_or_run_id)
try:
-        dag_run = (
-            session.query(DagRun)
-            .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
-            .one()
-        )
+        dag_run = session.scalars(
+            select(DagRun).where(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
+        ).one()
except NoResultFound:
if not create_if_necessary:
raise DagRunNotFound(
@@ -534,18 +533,14 @@ def _guess_debugger() -> _SupportedDebugger:
@provide_session
def task_states_for_dag_run(args, session: Session = NEW_SESSION) -> None:
"""Get the status of all task instances in a DagRun."""
-    dag_run = (
-        session.query(DagRun)
-        .filter(DagRun.run_id == args.execution_date_or_run_id, DagRun.dag_id == args.dag_id)
-        .one_or_none()
+    dag_run = session.scalar(
+        select(DagRun).where(DagRun.run_id == args.execution_date_or_run_id, DagRun.dag_id == args.dag_id)
     )
     if not dag_run:
         try:
             execution_date = timezone.parse(args.execution_date_or_run_id)
-            dag_run = (
-                session.query(DagRun)
-                .filter(DagRun.execution_date == execution_date, DagRun.dag_id == args.dag_id)
-                .one_or_none()
+            dag_run = session.scalar(
+                select(DagRun).where(DagRun.execution_date == execution_date, DagRun.dag_id == args.dag_id)
             )
except (ParserError, TypeError) as err:
             raise AirflowException(f"Error parsing the supplied execution_date. Error: {str(err)}")
diff --git a/airflow/cli/commands/variable_command.py
b/airflow/cli/commands/variable_command.py
index 009b4704aa..32f6b0c198 100644
--- a/airflow/cli/commands/variable_command.py
+++ b/airflow/cli/commands/variable_command.py
@@ -22,6 +22,8 @@ import json
import os
from json import JSONDecodeError
+from sqlalchemy import select
+
from airflow.cli.simple_table import AirflowConsole
from airflow.models import Variable
from airflow.utils import cli as cli_utils
@@ -33,7 +35,7 @@ from airflow.utils.session import create_session
def variables_list(args):
"""Displays all the variables."""
with create_session() as session:
- variables = session.query(Variable)
+ variables = session.scalars(select(Variable)).all()
         AirflowConsole().print_as(data=variables, output=args.output, mapper=lambda x: {"key": x.key})
@@ -107,7 +109,7 @@ def _variable_export_helper(filepath):
"""Helps export all the variables to the file."""
var_dict = {}
with create_session() as session:
- qry = session.query(Variable).all()
+ qry = session.scalars(select(Variable))
data = json.JSONDecoder()
for var in qry:
diff --git a/airflow/models/renderedtifields.py
b/airflow/models/renderedtifields.py
index 9586834470..37a5f08b27 100644
--- a/airflow/models/renderedtifields.py
+++ b/airflow/models/renderedtifields.py
@@ -134,15 +134,13 @@ class RenderedTaskInstanceFields(Base):
:param session: SqlAlchemy Session
:return: Rendered Templated TI field
"""
- result = (
- session.query(cls.rendered_fields)
- .filter(
+ result = session.scalar(
+ select(cls).where(
cls.dag_id == ti.dag_id,
cls.task_id == ti.task_id,
cls.run_id == ti.run_id,
cls.map_index == ti.map_index,
)
- .one_or_none()
)
if result:
@@ -162,15 +160,13 @@ class RenderedTaskInstanceFields(Base):
:param session: SqlAlchemy Session
:return: Kubernetes Pod Yaml
"""
- result = (
- session.query(cls.k8s_pod_yaml)
- .filter(
+ result = session.scalar(
+ select(cls).where(
cls.dag_id == ti.dag_id,
cls.task_id == ti.task_id,
cls.run_id == ti.run_id,
cls.map_index == ti.map_index,
)
- .one_or_none()
)
return result.k8s_pod_yaml if result else None
@@ -243,7 +239,8 @@ class RenderedTaskInstanceFields(Base):
cls.task_id == task_id,
tuple_not_in_condition(
(cls.dag_id, cls.task_id, cls.run_id),
-                        session.query(ti_clause.c.dag_id, ti_clause.c.task_id, ti_clause.c.run_id),
+                        select(ti_clause.c.dag_id, ti_clause.c.task_id, ti_clause.c.run_id),
+                        session=session,
),
)
.execution_options(synchronize_session=False)
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 32a5a796d5..04af5bddc3 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -22,14 +22,14 @@ import copy
import datetime
import json
import logging
-from typing import TYPE_CHECKING, Any, Generator, Iterable
+from typing import TYPE_CHECKING, Any, Generator, Iterable, overload
import pendulum
from dateutil import relativedelta
 from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, or_, true, tuple_
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy.exc import OperationalError
-from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql import ColumnElement, Select
from sqlalchemy.sql.expression import ColumnOperators
from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText
@@ -515,11 +515,31 @@ def is_lock_not_available_error(error: OperationalError):
return False
+@overload
def tuple_in_condition(
columns: tuple[ColumnElement, ...],
collection: Iterable[Any],
) -> ColumnOperators:
- """Generates a tuple-in-collection operator to use in ``.filter()``.
+ ...
+
+
+@overload
+def tuple_in_condition(
+ columns: tuple[ColumnElement, ...],
+ collection: Select,
+ *,
+ session: Session,
+) -> ColumnOperators:
+ ...
+
+
+def tuple_in_condition(
+ columns: tuple[ColumnElement, ...],
+ collection: Iterable[Any] | Select,
+ *,
+ session: Session | None = None,
+) -> ColumnOperators:
+ """Generates a tuple-in-collection operator to use in ``.where()``.
     For most SQL backends, this generates a simple ``([col, ...]) IN [condition]``
clause. This however does not work with MSSQL, where we need to expand to
@@ -529,17 +549,43 @@ def tuple_in_condition(
"""
if settings.engine.dialect.name != "mssql":
return tuple_(*columns).in_(collection)
-    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in collection]
+ if not isinstance(collection, Select):
+ rows = collection
+ elif session is None:
+ raise TypeError("session is required when passing in a subquery")
+ else:
+ rows = session.execute(collection)
+    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in rows]
if not clauses:
return false()
return or_(*clauses)
+@overload
def tuple_not_in_condition(
columns: tuple[ColumnElement, ...],
collection: Iterable[Any],
) -> ColumnOperators:
- """Generates a tuple-not-in-collection operator to use in ``.filter()``.
+ ...
+
+
+@overload
+def tuple_not_in_condition(
+ columns: tuple[ColumnElement, ...],
+ collection: Select,
+ *,
+ session: Session,
+) -> ColumnOperators:
+ ...
+
+
+def tuple_not_in_condition(
+ columns: tuple[ColumnElement, ...],
+ collection: Iterable[Any] | Select,
+ *,
+ session: Session | None = None,
+) -> ColumnOperators:
+ """Generates a tuple-not-in-collection operator to use in ``.where()``.
This is similar to ``tuple_in_condition`` except generating ``NOT IN``.
@@ -547,7 +593,13 @@ def tuple_not_in_condition(
"""
if settings.engine.dialect.name != "mssql":
return tuple_(*columns).not_in(collection)
-    clauses = [or_(*(c != v for c, v in zip(columns, values))) for values in collection]
+ if not isinstance(collection, Select):
+ rows = collection
+ elif session is None:
+ raise TypeError("session is required when passing in a subquery")
+ else:
+ rows = session.execute(collection)
+    clauses = [or_(*(c != v for c, v in zip(columns, values))) for values in rows]
if not clauses:
return true()
return and_(*clauses)
diff --git a/tests/cli/commands/test_task_command.py
b/tests/cli/commands/test_task_command.py
index 7c2eb61d32..74a7dcea0d 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -500,15 +500,15 @@ class TestCliTasks:
assert 'echo "2022-01-01"' in output
assert 'echo "2022-01-08"' in output
- @mock.patch("sqlalchemy.orm.session.Session.query")
+ @mock.patch("airflow.cli.commands.task_command.select")
+ @mock.patch("airflow.cli.commands.task_command.Session.scalars")
@mock.patch("airflow.cli.commands.task_command.DagRun")
-    def test_task_render_with_custom_timetable(self, mock_dagrun, mock_query):
+    def test_task_render_with_custom_timetable(self, mock_dagrun, mock_scalars, mock_select):
"""
when calling `tasks render` on dag with custom timetable, the DagRun
object should be created with
data_intervals.
"""
- mock_query.side_effect = sqlalchemy.exc.NoResultFound
-
+ mock_scalars.side_effect = sqlalchemy.exc.NoResultFound
task_command.task_render(
             self.parser.parse_args(["tasks", "render", "example_workday_timetable", "run_this", "2022-01-01"])
)