This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 26bb408d17 Refactor Sqlalchemy queries to 2.0 style (Part 4) (#32339)
26bb408d17 is described below

commit 26bb408d17d6eecb6f67e31e536f052a37a66db8
Author: Phani Kumar <[email protected]>
AuthorDate: Fri Jul 7 16:52:06 2023 +0530

    Refactor Sqlalchemy queries to 2.0 style (Part 4) (#32339)
---
 airflow/cli/commands/connection_command.py    | 14 ++++++++------
 airflow/secrets/metastore.py                  |  5 +++--
 airflow/www/utils.py                          | 17 +++++++++--------
 tests/cli/commands/test_connection_command.py |  4 +---
 4 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py
index 63888c0ec9..e7b83e342e 100644
--- a/airflow/cli/commands/connection_command.py
+++ b/airflow/cli/commands/connection_command.py
@@ -26,6 +26,7 @@ from pathlib import Path
 from typing import Any
 from urllib.parse import urlsplit, urlunsplit
 
+from sqlalchemy import select
 from sqlalchemy.orm import exc
 
 from airflow.cli.simple_table import AirflowConsole
@@ -77,9 +78,10 @@ def connections_get(args):
 def connections_list(args):
     """Lists all connections at the command line."""
     with create_session() as session:
-        query = session.query(Connection)
+        query = select(Connection)
         if args.conn_id:
-            query = query.filter(Connection.conn_id == args.conn_id)
+            query = query.where(Connection.conn_id == args.conn_id)
+        query = session.scalars(query)
         conns = query.all()
 
         AirflowConsole().print_as(
@@ -177,7 +179,7 @@ def connections_export(args):
         raise SystemExit("Option `--serialization-format` may only be used 
with file type `env`.")
 
     with create_session() as session:
-        connections = 
session.query(Connection).order_by(Connection.conn_id).all()
+        connections = 
session.scalars(select(Connection).order_by(Connection.conn_id)).all()
 
     msg = _format_connections(
         conns=connections,
@@ -265,7 +267,7 @@ def connections_add(args):
             new_conn.set_extra(args.conn_extra)
 
     with create_session() as session:
-        if not session.query(Connection).filter(Connection.conn_id == 
new_conn.conn_id).first():
+        if not session.scalar(select(Connection).where(Connection.conn_id == 
new_conn.conn_id).limit(1)):
             session.add(new_conn)
             msg = "Successfully added `conn_id`={conn_id} : {uri}"
             msg = msg.format(
@@ -293,7 +295,7 @@ def connections_delete(args):
     """Deletes connection from DB."""
     with create_session() as session:
         try:
-            to_delete = session.query(Connection).filter(Connection.conn_id == 
args.conn_id).one()
+            to_delete = 
session.scalars(select(Connection).where(Connection.conn_id == 
args.conn_id)).one()
         except exc.NoResultFound:
             raise SystemExit(f"Did not find a connection with 
`conn_id`={args.conn_id}")
         except exc.MultipleResultsFound:
@@ -326,7 +328,7 @@ def _import_helper(file_path: str, overwrite: bool) -> None:
                 print(f"Could not import connection. {e}")
                 continue
 
-            existing_conn_id = 
session.query(Connection.id).filter(Connection.conn_id == conn_id).scalar()
+            existing_conn_id = 
session.scalar(select(Connection.id).where(Connection.conn_id == conn_id))
             if existing_conn_id is not None:
                 if not overwrite:
                     print(f"Could not import connection {conn_id}: connection 
already exists.")
diff --git a/airflow/secrets/metastore.py b/airflow/secrets/metastore.py
index dc81675d77..160c063dff 100644
--- a/airflow/secrets/metastore.py
+++ b/airflow/secrets/metastore.py
@@ -21,6 +21,7 @@ from __future__ import annotations
 import warnings
 from typing import TYPE_CHECKING
 
+from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from airflow.exceptions import RemovedInAirflow3Warning
@@ -38,7 +39,7 @@ class MetastoreBackend(BaseSecretsBackend):
     def get_connection(self, conn_id: str, session: Session = NEW_SESSION) -> 
Connection | None:
         from airflow.models.connection import Connection
 
-        conn = session.query(Connection).filter(Connection.conn_id == 
conn_id).first()
+        conn = session.scalar(select(Connection).where(Connection.conn_id == 
conn_id).limit(1))
         session.expunge_all()
         return conn
 
@@ -65,7 +66,7 @@ class MetastoreBackend(BaseSecretsBackend):
         """
         from airflow.models.variable import Variable
 
-        var_value = session.query(Variable).filter(Variable.key == key).first()
+        var_value = session.scalar(select(Variable).where(Variable.key == 
key).limit(1))
         session.expunge_all()
         if var_value:
             return var_value.val
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 4aaeda595b..1ca6289122 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -37,7 +37,7 @@ from pendulum.datetime import DateTime
 from pygments import highlight, lexers
 from pygments.formatters import HtmlFormatter
 from pygments.lexer import Lexer
-from sqlalchemy import delete, func, types
+from sqlalchemy import delete, func, select, types
 from sqlalchemy.ext.associationproxy import AssociationProxy
 from sqlalchemy.sql import Select
 
@@ -68,16 +68,15 @@ def datetime_to_string(value: DateTime | None) -> str | None:
 
 
 def get_mapped_instances(task_instance, session):
-    return (
-        session.query(TaskInstance)
-        .filter(
+    return session.scalars(
+        select(TaskInstance)
+        .where(
             TaskInstance.dag_id == task_instance.dag_id,
             TaskInstance.run_id == task_instance.run_id,
             TaskInstance.task_id == task_instance.task_id,
         )
         .order_by(TaskInstance.map_index)
-        .all()
-    )
+    ).all()
 
 
 def get_instance_with_map(task_instance, session):
@@ -179,14 +178,16 @@ def encode_dag_run(
 
 def check_import_errors(fileloc, session):
     # Check dag import errors
-    import_errors = 
session.query(errors.ImportError).filter(errors.ImportError.filename == 
fileloc).all()
+    import_errors = session.scalars(
+        select(errors.ImportError).where(errors.ImportError.filename == 
fileloc)
+    ).all()
     if import_errors:
         for import_error in import_errors:
             flash("Broken DAG: [{ie.filename}] 
{ie.stacktrace}".format(ie=import_error), "dag_import_error")
 
 
 def check_dag_warnings(dag_id, session):
-    dag_warnings = session.query(DagWarning).filter(DagWarning.dag_id == 
dag_id).all()
+    dag_warnings = session.scalars(select(DagWarning).where(DagWarning.dag_id 
== dag_id)).all()
     if dag_warnings:
         for dag_warning in dag_warnings:
             flash(dag_warning.message, "warning")
diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py
index 12b3de6170..0c48cd099b 100644
--- a/tests/cli/commands/test_connection_command.py
+++ b/tests/cli/commands/test_connection_command.py
@@ -167,9 +167,7 @@ class TestCliExportConnections:
         def my_side_effect(_):
             raise Exception("dummy exception")
 
-        
mock_session.return_value.__enter__.return_value.query.return_value.order_by.side_effect
 = (
-            my_side_effect
-        )
+        mock_session.return_value.__enter__.return_value.scalars.side_effect = 
my_side_effect
         args = self.parser.parse_args(["connections", "export", 
output_filepath.as_posix()])
         with pytest.raises(Exception, match=r"dummy exception"):
             connection_command.connections_export(args)

Reply via email to