This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f44ea5996 Refactor Sqlalchemy queries to 2.0 style (Part 5) (#32474)
2f44ea5996 is described below

commit 2f44ea5996b9664865d6d891537f770caed19a0d
Author: Phani Kumar <[email protected]>
AuthorDate: Wed Jul 12 11:00:24 2023 +0530

    Refactor Sqlalchemy queries to 2.0 style (Part 5) (#32474)
---
 airflow/cli/commands/dag_command.py               | 22 ++++----
 airflow/cli/commands/jobs_command.py              | 11 ++--
 airflow/cli/commands/rotate_fernet_key_command.py |  7 ++-
 airflow/cli/commands/task_command.py              | 21 +++-----
 airflow/cli/commands/variable_command.py          |  6 ++-
 airflow/models/renderedtifields.py                | 15 +++---
 airflow/utils/sqlalchemy.py                       | 64 ++++++++++++++++++++---
 tests/cli/commands/test_task_command.py           |  8 +--
 8 files changed, 103 insertions(+), 51 deletions(-)

diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py
index f57e26b38c..66decad78b 100644
--- a/airflow/cli/commands/dag_command.py
+++ b/airflow/cli/commands/dag_command.py
@@ -28,7 +28,7 @@ import sys
 import warnings
 
 from graphviz.dot import Dot
-from sqlalchemy import delete
+from sqlalchemy import delete, select
 from sqlalchemy.orm import Session
 
 from airflow import settings
@@ -287,7 +287,7 @@ def dag_state(args, session: Session = NEW_SESSION) -> None:
 
     if not dag:
         raise SystemExit(f"DAG: {args.dag_id} does not exist in 'dag' table")
-    dr = session.query(DagRun).filter_by(dag_id=args.dag_id, 
execution_date=args.execution_date).one_or_none()
+    dr = session.scalar(select(DagRun).filter_by(dag_id=args.dag_id, 
execution_date=args.execution_date))
     out = dr.state if dr else None
     conf_out = ""
     if out and dr.conf:
@@ -309,7 +309,9 @@ def dag_next_execution(args) -> None:
         print("[INFO] Please be reminded this DAG is PAUSED now.", 
file=sys.stderr)
 
     with create_session() as session:
-        last_parsed_dag: DagModel = 
session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).one()
+        last_parsed_dag: DagModel = session.scalars(
+            select(DagModel).where(DagModel.dag_id == dag.dag_id)
+        ).one()
 
     def print_execution_interval(interval: DataInterval | None):
         if interval is None:
@@ -428,8 +430,10 @@ def dag_list_jobs(args, dag: DAG | None = None, session: Session = NEW_SESSION)
         queries.append(Job.state == args.state)
 
     fields = ["dag_id", "state", "job_type", "start_date", "end_date"]
-    all_jobs = 
session.query(Job).filter(*queries).order_by(Job.start_date.desc()).limit(args.limit).all()
-    all_jobs = [{f: str(job.__getattribute__(f)) for f in fields} for job in 
all_jobs]
+    all_jobs_iter = session.scalars(
+        
select(Job).where(*queries).order_by(Job.start_date.desc()).limit(args.limit)
+    )
+    all_jobs = [{f: str(job.__getattribute__(f)) for f in fields} for job in 
all_jobs_iter]
 
     AirflowConsole().print_as(
         data=all_jobs,
@@ -492,14 +496,12 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No
     imgcat = args.imgcat_dagrun
     filename = args.save_dagrun
     if show_dagrun or imgcat or filename:
-        tis = (
-            session.query(TaskInstance)
-            .filter(
+        tis = session.scalars(
+            select(TaskInstance).where(
                 TaskInstance.dag_id == args.dag_id,
                 TaskInstance.execution_date == execution_date,
             )
-            .all()
-        )
+        ).all()
 
         dot_graph = render_dag(dag, tis=tis)
         print()
diff --git a/airflow/cli/commands/jobs_command.py b/airflow/cli/commands/jobs_command.py
index 79e0aad79f..339896d877 100644
--- a/airflow/cli/commands/jobs_command.py
+++ b/airflow/cli/commands/jobs_command.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from airflow.jobs.job import Job
@@ -32,17 +33,17 @@ def check(args, session: Session = NEW_SESSION) -> None:
     if args.hostname and args.local:
         raise SystemExit("You can't use --hostname and --local at the same 
time")
 
-    query = session.query(Job).filter(Job.state == 
State.RUNNING).order_by(Job.latest_heartbeat.desc())
+    query = select(Job).where(Job.state == 
State.RUNNING).order_by(Job.latest_heartbeat.desc())
     if args.job_type:
-        query = query.filter(Job.job_type == args.job_type)
+        query = query.where(Job.job_type == args.job_type)
     if args.hostname:
-        query = query.filter(Job.hostname == args.hostname)
+        query = query.where(Job.hostname == args.hostname)
     if args.local:
-        query = query.filter(Job.hostname == get_hostname())
+        query = query.where(Job.hostname == get_hostname())
     if args.limit > 0:
         query = query.limit(args.limit)
 
-    alive_jobs: list[Job] = [job for job in query.all() if job.is_alive()]
+    alive_jobs: list[Job] = [job for job in session.scalars(query) if 
job.is_alive()]
 
     count_alive_jobs = len(alive_jobs)
     if count_alive_jobs == 0:
diff --git a/airflow/cli/commands/rotate_fernet_key_command.py b/airflow/cli/commands/rotate_fernet_key_command.py
index f9e1873597..e9973978e0 100644
--- a/airflow/cli/commands/rotate_fernet_key_command.py
+++ b/airflow/cli/commands/rotate_fernet_key_command.py
@@ -17,6 +17,8 @@
 """Rotate Fernet key command."""
 from __future__ import annotations
 
+from sqlalchemy import select
+
 from airflow.models import Connection, Variable
 from airflow.utils import cli as cli_utils
 from airflow.utils.session import create_session
@@ -26,7 +28,8 @@ from airflow.utils.session import create_session
 def rotate_fernet_key(args):
     """Rotates all encrypted connection credentials and variables."""
     with create_session() as session:
-        for conn in session.query(Connection).filter(Connection.is_encrypted | 
Connection.is_extra_encrypted):
+        conns_query = select(Connection).where(Connection.is_encrypted | 
Connection.is_extra_encrypted)
+        for conn in session.scalars(conns_query):
             conn.rotate_fernet_key()
-        for var in session.query(Variable).filter(Variable.is_encrypted):
+        for var in 
session.scalars(select(Variable).where(Variable.is_encrypted)):
             var.rotate_fernet_key()
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index a5e20e8751..796d3a8cf1 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -29,6 +29,7 @@ from typing import Generator, Union, cast
 
 import pendulum
 from pendulum.parsing.exceptions import ParserError
+from sqlalchemy import select
 from sqlalchemy.orm.exc import NoResultFound
 from sqlalchemy.orm.session import Session
 
@@ -111,11 +112,9 @@ def _get_dag_run(
         with suppress(ParserError, TypeError):
             execution_date = timezone.parse(exec_date_or_run_id)
         try:
-            dag_run = (
-                session.query(DagRun)
-                .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == 
execution_date)
-                .one()
-            )
+            dag_run = session.scalars(
+                select(DagRun).where(DagRun.dag_id == dag.dag_id, 
DagRun.execution_date == execution_date)
+            ).one()
         except NoResultFound:
             if not create_if_necessary:
                 raise DagRunNotFound(
@@ -534,18 +533,14 @@ def _guess_debugger() -> _SupportedDebugger:
 @provide_session
 def task_states_for_dag_run(args, session: Session = NEW_SESSION) -> None:
     """Get the status of all task instances in a DagRun."""
-    dag_run = (
-        session.query(DagRun)
-        .filter(DagRun.run_id == args.execution_date_or_run_id, DagRun.dag_id 
== args.dag_id)
-        .one_or_none()
+    dag_run = session.scalar(
+        select(DagRun).where(DagRun.run_id == args.execution_date_or_run_id, 
DagRun.dag_id == args.dag_id)
     )
     if not dag_run:
         try:
             execution_date = timezone.parse(args.execution_date_or_run_id)
-            dag_run = (
-                session.query(DagRun)
-                .filter(DagRun.execution_date == execution_date, DagRun.dag_id 
== args.dag_id)
-                .one_or_none()
+            dag_run = session.scalar(
+                select(DagRun).where(DagRun.execution_date == execution_date, 
DagRun.dag_id == args.dag_id)
             )
         except (ParserError, TypeError) as err:
             raise AirflowException(f"Error parsing the supplied 
execution_date. Error: {str(err)}")
diff --git a/airflow/cli/commands/variable_command.py b/airflow/cli/commands/variable_command.py
index 009b4704aa..32f6b0c198 100644
--- a/airflow/cli/commands/variable_command.py
+++ b/airflow/cli/commands/variable_command.py
@@ -22,6 +22,8 @@ import json
 import os
 from json import JSONDecodeError
 
+from sqlalchemy import select
+
 from airflow.cli.simple_table import AirflowConsole
 from airflow.models import Variable
 from airflow.utils import cli as cli_utils
@@ -33,7 +35,7 @@ from airflow.utils.session import create_session
 def variables_list(args):
     """Displays all the variables."""
     with create_session() as session:
-        variables = session.query(Variable)
+        variables = session.scalars(select(Variable)).all()
     AirflowConsole().print_as(data=variables, output=args.output, 
mapper=lambda x: {"key": x.key})
 
 
@@ -107,7 +109,7 @@ def _variable_export_helper(filepath):
     """Helps export all the variables to the file."""
     var_dict = {}
     with create_session() as session:
-        qry = session.query(Variable).all()
+        qry = session.scalars(select(Variable))
 
         data = json.JSONDecoder()
         for var in qry:
diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py
index 9586834470..37a5f08b27 100644
--- a/airflow/models/renderedtifields.py
+++ b/airflow/models/renderedtifields.py
@@ -134,15 +134,13 @@ class RenderedTaskInstanceFields(Base):
         :param session: SqlAlchemy Session
         :return: Rendered Templated TI field
         """
-        result = (
-            session.query(cls.rendered_fields)
-            .filter(
+        result = session.scalar(
+            select(cls).where(
                 cls.dag_id == ti.dag_id,
                 cls.task_id == ti.task_id,
                 cls.run_id == ti.run_id,
                 cls.map_index == ti.map_index,
             )
-            .one_or_none()
         )
 
         if result:
@@ -162,15 +160,13 @@ class RenderedTaskInstanceFields(Base):
         :param session: SqlAlchemy Session
         :return: Kubernetes Pod Yaml
         """
-        result = (
-            session.query(cls.k8s_pod_yaml)
-            .filter(
+        result = session.scalar(
+            select(cls).where(
                 cls.dag_id == ti.dag_id,
                 cls.task_id == ti.task_id,
                 cls.run_id == ti.run_id,
                 cls.map_index == ti.map_index,
             )
-            .one_or_none()
         )
         return result.k8s_pod_yaml if result else None
 
@@ -243,7 +239,8 @@ class RenderedTaskInstanceFields(Base):
                 cls.task_id == task_id,
                 tuple_not_in_condition(
                     (cls.dag_id, cls.task_id, cls.run_id),
-                    session.query(ti_clause.c.dag_id, ti_clause.c.task_id, 
ti_clause.c.run_id),
+                    select(ti_clause.c.dag_id, ti_clause.c.task_id, 
ti_clause.c.run_id),
+                    session=session,
                 ),
             )
             .execution_options(synchronize_session=False)
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 32a5a796d5..04af5bddc3 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -22,14 +22,14 @@ import copy
 import datetime
 import json
 import logging
-from typing import TYPE_CHECKING, Any, Generator, Iterable
+from typing import TYPE_CHECKING, Any, Generator, Iterable, overload
 
 import pendulum
 from dateutil import relativedelta
 from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, 
or_, true, tuple_
 from sqlalchemy.dialects import mssql, mysql
 from sqlalchemy.exc import OperationalError
-from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql import ColumnElement, Select
 from sqlalchemy.sql.expression import ColumnOperators
 from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText
 
@@ -515,11 +515,31 @@ def is_lock_not_available_error(error: OperationalError):
     return False
 
 
+@overload
 def tuple_in_condition(
     columns: tuple[ColumnElement, ...],
     collection: Iterable[Any],
 ) -> ColumnOperators:
-    """Generates a tuple-in-collection operator to use in ``.filter()``.
+    ...
+
+
+@overload
+def tuple_in_condition(
+    columns: tuple[ColumnElement, ...],
+    collection: Select,
+    *,
+    session: Session,
+) -> ColumnOperators:
+    ...
+
+
+def tuple_in_condition(
+    columns: tuple[ColumnElement, ...],
+    collection: Iterable[Any] | Select,
+    *,
+    session: Session | None = None,
+) -> ColumnOperators:
+    """Generates a tuple-in-collection operator to use in ``.where()``.
 
     For most SQL backends, this generates a simple ``([col, ...]) IN 
[condition]``
     clause. This however does not work with MSSQL, where we need to expand to
@@ -529,17 +549,43 @@ def tuple_in_condition(
     """
     if settings.engine.dialect.name != "mssql":
         return tuple_(*columns).in_(collection)
-    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in 
collection]
+    if not isinstance(collection, Select):
+        rows = collection
+    elif session is None:
+        raise TypeError("session is required when passing in a subquery")
+    else:
+        rows = session.execute(collection)
+    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in 
rows]
     if not clauses:
         return false()
     return or_(*clauses)
 
 
+@overload
 def tuple_not_in_condition(
     columns: tuple[ColumnElement, ...],
     collection: Iterable[Any],
 ) -> ColumnOperators:
-    """Generates a tuple-not-in-collection operator to use in ``.filter()``.
+    ...
+
+
+@overload
+def tuple_not_in_condition(
+    columns: tuple[ColumnElement, ...],
+    collection: Select,
+    *,
+    session: Session,
+) -> ColumnOperators:
+    ...
+
+
+def tuple_not_in_condition(
+    columns: tuple[ColumnElement, ...],
+    collection: Iterable[Any] | Select,
+    *,
+    session: Session | None = None,
+) -> ColumnOperators:
+    """Generates a tuple-not-in-collection operator to use in ``.where()``.
 
     This is similar to ``tuple_in_condition`` except generating ``NOT IN``.
 
@@ -547,7 +593,13 @@ def tuple_not_in_condition(
     """
     if settings.engine.dialect.name != "mssql":
         return tuple_(*columns).not_in(collection)
-    clauses = [or_(*(c != v for c, v in zip(columns, values))) for values in 
collection]
+    if not isinstance(collection, Select):
+        rows = collection
+    elif session is None:
+        raise TypeError("session is required when passing in a subquery")
+    else:
+        rows = session.execute(collection)
+    clauses = [or_(*(c != v for c, v in zip(columns, values))) for values in 
rows]
     if not clauses:
         return true()
     return and_(*clauses)
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 7c2eb61d32..74a7dcea0d 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -500,15 +500,15 @@ class TestCliTasks:
         assert 'echo "2022-01-01"' in output
         assert 'echo "2022-01-08"' in output
 
-    @mock.patch("sqlalchemy.orm.session.Session.query")
+    @mock.patch("airflow.cli.commands.task_command.select")
+    @mock.patch("airflow.cli.commands.task_command.Session.scalars")
     @mock.patch("airflow.cli.commands.task_command.DagRun")
-    def test_task_render_with_custom_timetable(self, mock_dagrun, mock_query):
+    def test_task_render_with_custom_timetable(self, mock_dagrun, 
mock_scalars, mock_select):
         """
         when calling `tasks render` on dag with custom timetable, the DagRun 
object should be created with
          data_intervals.
         """
-        mock_query.side_effect = sqlalchemy.exc.NoResultFound
-
+        mock_scalars.side_effect = sqlalchemy.exc.NoResultFound
         task_command.task_render(
             self.parser.parse_args(["tasks", "render", 
"example_workday_timetable", "run_this", "2022-01-01"])
         )

Reply via email to