This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 26bb408d17 Refactor Sqlalchemy queries to 2.0 style (Part 4) (#32339)
26bb408d17 is described below
commit 26bb408d17d6eecb6f67e31e536f052a37a66db8
Author: Phani Kumar <[email protected]>
AuthorDate: Fri Jul 7 16:52:06 2023 +0530
Refactor Sqlalchemy queries to 2.0 style (Part 4) (#32339)
---
airflow/cli/commands/connection_command.py | 14 ++++++++------
airflow/secrets/metastore.py | 5 +++--
airflow/www/utils.py | 17 +++++++++--------
tests/cli/commands/test_connection_command.py | 4 +---
4 files changed, 21 insertions(+), 19 deletions(-)
diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py
index 63888c0ec9..e7b83e342e 100644
--- a/airflow/cli/commands/connection_command.py
+++ b/airflow/cli/commands/connection_command.py
@@ -26,6 +26,7 @@ from pathlib import Path
from typing import Any
from urllib.parse import urlsplit, urlunsplit
+from sqlalchemy import select
from sqlalchemy.orm import exc
from airflow.cli.simple_table import AirflowConsole
@@ -77,9 +78,10 @@ def connections_get(args):
def connections_list(args):
"""Lists all connections at the command line."""
with create_session() as session:
- query = session.query(Connection)
+ query = select(Connection)
if args.conn_id:
- query = query.filter(Connection.conn_id == args.conn_id)
+ query = query.where(Connection.conn_id == args.conn_id)
+ query = session.scalars(query)
conns = query.all()
AirflowConsole().print_as(
@@ -177,7 +179,7 @@ def connections_export(args):
raise SystemExit("Option `--serialization-format` may only be used
with file type `env`.")
with create_session() as session:
- connections =
session.query(Connection).order_by(Connection.conn_id).all()
+ connections =
session.scalars(select(Connection).order_by(Connection.conn_id)).all()
msg = _format_connections(
conns=connections,
@@ -265,7 +267,7 @@ def connections_add(args):
new_conn.set_extra(args.conn_extra)
with create_session() as session:
- if not session.query(Connection).filter(Connection.conn_id ==
new_conn.conn_id).first():
+ if not session.scalar(select(Connection).where(Connection.conn_id ==
new_conn.conn_id).limit(1)):
session.add(new_conn)
msg = "Successfully added `conn_id`={conn_id} : {uri}"
msg = msg.format(
@@ -293,7 +295,7 @@ def connections_delete(args):
"""Deletes connection from DB."""
with create_session() as session:
try:
- to_delete = session.query(Connection).filter(Connection.conn_id ==
args.conn_id).one()
+ to_delete =
session.scalars(select(Connection).where(Connection.conn_id ==
args.conn_id)).one()
except exc.NoResultFound:
raise SystemExit(f"Did not find a connection with
`conn_id`={args.conn_id}")
except exc.MultipleResultsFound:
@@ -326,7 +328,7 @@ def _import_helper(file_path: str, overwrite: bool) -> None:
print(f"Could not import connection. {e}")
continue
- existing_conn_id =
session.query(Connection.id).filter(Connection.conn_id == conn_id).scalar()
+ existing_conn_id =
session.scalar(select(Connection.id).where(Connection.conn_id == conn_id))
if existing_conn_id is not None:
if not overwrite:
print(f"Could not import connection {conn_id}: connection
already exists.")
diff --git a/airflow/secrets/metastore.py b/airflow/secrets/metastore.py
index dc81675d77..160c063dff 100644
--- a/airflow/secrets/metastore.py
+++ b/airflow/secrets/metastore.py
@@ -21,6 +21,7 @@ from __future__ import annotations
import warnings
from typing import TYPE_CHECKING
+from sqlalchemy import select
from sqlalchemy.orm import Session
from airflow.exceptions import RemovedInAirflow3Warning
@@ -38,7 +39,7 @@ class MetastoreBackend(BaseSecretsBackend):
def get_connection(self, conn_id: str, session: Session = NEW_SESSION) ->
Connection | None:
from airflow.models.connection import Connection
- conn = session.query(Connection).filter(Connection.conn_id ==
conn_id).first()
+ conn = session.scalar(select(Connection).where(Connection.conn_id ==
conn_id).limit(1))
session.expunge_all()
return conn
@@ -65,7 +66,7 @@ class MetastoreBackend(BaseSecretsBackend):
"""
from airflow.models.variable import Variable
- var_value = session.query(Variable).filter(Variable.key == key).first()
+ var_value = session.scalar(select(Variable).where(Variable.key ==
key).limit(1))
session.expunge_all()
if var_value:
return var_value.val
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 4aaeda595b..1ca6289122 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -37,7 +37,7 @@ from pendulum.datetime import DateTime
from pygments import highlight, lexers
from pygments.formatters import HtmlFormatter
from pygments.lexer import Lexer
-from sqlalchemy import delete, func, types
+from sqlalchemy import delete, func, select, types
from sqlalchemy.ext.associationproxy import AssociationProxy
from sqlalchemy.sql import Select
@@ -68,16 +68,15 @@ def datetime_to_string(value: DateTime | None) -> str | None:
def get_mapped_instances(task_instance, session):
- return (
- session.query(TaskInstance)
- .filter(
+ return session.scalars(
+ select(TaskInstance)
+ .where(
TaskInstance.dag_id == task_instance.dag_id,
TaskInstance.run_id == task_instance.run_id,
TaskInstance.task_id == task_instance.task_id,
)
.order_by(TaskInstance.map_index)
- .all()
- )
+ ).all()
def get_instance_with_map(task_instance, session):
@@ -179,14 +178,16 @@ def encode_dag_run(
def check_import_errors(fileloc, session):
# Check dag import errors
- import_errors =
session.query(errors.ImportError).filter(errors.ImportError.filename ==
fileloc).all()
+ import_errors = session.scalars(
+ select(errors.ImportError).where(errors.ImportError.filename ==
fileloc)
+ ).all()
if import_errors:
for import_error in import_errors:
flash("Broken DAG: [{ie.filename}]
{ie.stacktrace}".format(ie=import_error), "dag_import_error")
def check_dag_warnings(dag_id, session):
- dag_warnings = session.query(DagWarning).filter(DagWarning.dag_id ==
dag_id).all()
+ dag_warnings = session.scalars(select(DagWarning).where(DagWarning.dag_id
== dag_id)).all()
if dag_warnings:
for dag_warning in dag_warnings:
flash(dag_warning.message, "warning")
diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py
index 12b3de6170..0c48cd099b 100644
--- a/tests/cli/commands/test_connection_command.py
+++ b/tests/cli/commands/test_connection_command.py
@@ -167,9 +167,7 @@ class TestCliExportConnections:
def my_side_effect(_):
raise Exception("dummy exception")
-
mock_session.return_value.__enter__.return_value.query.return_value.order_by.side_effect
= (
- my_side_effect
- )
+ mock_session.return_value.__enter__.return_value.scalars.side_effect =
my_side_effect
args = self.parser.parse_args(["connections", "export",
output_filepath.as_posix()])
with pytest.raises(Exception, match=r"dummy exception"):
connection_command.connections_export(args)