This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3f6ac2f216 SQL query improvements in utils/db.py (#32518)
3f6ac2f216 is described below

commit 3f6ac2f216808cb3f227bb77ba751cb17fbd2a14
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Jul 11 18:58:31 2023 +0800

    SQL query improvements in utils/db.py (#32518)
---
 airflow/utils/db.py | 38 +++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 46ddcbfb34..f846b3f3a6 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -30,7 +30,6 @@ from tempfile import gettempdir
 from typing import TYPE_CHECKING, Callable, Generator, Iterable
 
 from sqlalchemy import Table, and_, column, delete, exc, func, inspect, or_, select, table, text, tuple_
-from sqlalchemy.orm.session import Session
 
 import airflow
 from airflow import settings
@@ -45,9 +44,10 @@ from airflow.utils.session import NEW_SESSION, 
create_session, provide_session
 if TYPE_CHECKING:
     from alembic.runtime.environment import EnvironmentContext
     from alembic.script import ScriptDirectory
-    from sqlalchemy.orm import Query
+    from sqlalchemy.orm import Query, Session
 
     from airflow.models.base import Base
+    from airflow.models.connection import Connection
 
 log = logging.getLogger(__name__)
 
@@ -90,9 +90,9 @@ def _format_airflow_moved_table_name(source_table, version, category):
 
 
 @provide_session
-def merge_conn(conn, session: Session = NEW_SESSION):
+def merge_conn(conn: Connection, session: Session = NEW_SESSION):
     """Add new Connection."""
-    if not session.scalar(select(conn.__class__).filter_by(conn_id=conn.conn_id).limit(1)):
+    if not session.scalar(select(1).where(conn.__class__.conn_id == conn.conn_id)):
         session.add(conn)
         session.commit()
 
@@ -957,20 +957,20 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]:
     """
     from airflow.models.connection import Connection
 
-    dups = []
     try:
-        dups = session.execute(
+        dups = session.scalars(
             select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1)
         ).all()
     except (exc.OperationalError, exc.ProgrammingError):
         # fallback if tables hasn't been created yet
         session.rollback()
+        return
     if dups:
         yield (
             "Seems you have non unique conn_id in connection table.\n"
             "You have to manage those duplicate connections "
             "before upgrading the database.\n"
-            f"Duplicated conn_id: {[dup.conn_id for dup in dups]}"
+            f"Duplicated conn_id: {dups}"
         )
 
 
@@ -1057,11 +1057,11 @@ def check_table_for_duplicates(
     :param uniqueness: uniqueness constraint to evaluate against
     :param session:  session of the sqlalchemy
     """
-    minimal_table_obj = table(table_name, *[column(x) for x in uniqueness])
+    minimal_table_obj = table(table_name, *(column(x) for x in uniqueness))
     try:
         subquery = session.execute(
             select(minimal_table_obj, func.count().label("dupe_count"))
-            .group_by(*[text(x) for x in uniqueness])
+            .group_by(*(text(x) for x in uniqueness))
             .having(func.count() > text("1"))
             .subquery()
         )
@@ -1100,12 +1100,12 @@ def check_conn_type_null(session: Session) -> Iterable[str]:
     """
     from airflow.models.connection import Connection
 
-    n_nulls = []
     try:
         n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all()
     except (exc.OperationalError, exc.ProgrammingError, exc.InternalError):
         # fallback if tables hasn't been created yet
         session.rollback()
+        return
 
     if n_nulls:
         yield (
@@ -1113,7 +1113,7 @@ def check_conn_type_null(session: Session) -> Iterable[str]:
             "table must contain content.\n"
             "Make sure you don't have null "
             "in the conn_type column.\n"
-            f"Null conn_type conn_id: {list(n_nulls)}"
+            f"Null conn_type conn_id: {n_nulls}"
         )
 
 
@@ -1265,7 +1265,7 @@ def _dangling_against_dag_run(session, source_table, dag_run):
     )
 
     return (
-        select(*[c.label(c.name) for c in source_table.c])
+        select(*(c.label(c.name) for c in source_table.c))
         .join(dag_run, source_to_dag_run_join_cond, isouter=True)
         .where(dag_run.c.dag_id.is_(None))
     )
@@ -1306,9 +1306,9 @@ def _dangling_against_task_instance(session, source_table, dag_run, task_instanc
         )
 
     return (
-        select(*[c.label(c.name) for c in source_table.c])
-        .join(dag_run, dr_join_cond, isouter=True)
-        .join(task_instance, ti_join_cond, isouter=True)
+        select(*(c.label(c.name) for c in source_table.c))
+        .outerjoin(dag_run, dr_join_cond)
+        .outerjoin(task_instance, ti_join_cond)
         .where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
     )
 
@@ -1335,9 +1335,9 @@ def _move_duplicate_data_to_new_table(
     dialect_name = bind.dialect.name
 
     query = (
-        select(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns])
+        select(*(source_table.c[x.name].label(str(x.name)) for x in source_table.columns))
         .select_from(source_table)
-        .join(subquery, and_(*[getattr(source_table.c, x) == getattr(subquery.c, x) for x in uniqueness]))
+        .join(subquery, and_(*(source_table.c[x] == subquery.c[x] for x in uniqueness)))
     )
 
     _create_table_as(
@@ -1353,7 +1353,7 @@ def _move_duplicate_data_to_new_table(
 
     metadata = reflect_tables([target_table_name], session)
     target_table = metadata.tables[target_table_name]
-    where_clause = and_(*[getattr(source_table.c, x) == getattr(target_table.c, x) for x in uniqueness])
+    where_clause = and_(*(source_table.c[x] == target_table.c[x] for x in uniqueness))
 
     if dialect_name == "sqlite":
         subq = query.selectable.with_only_columns([text(f"{source_table}.ROWID")])
@@ -1410,7 +1410,7 @@ def check_bad_references(session: Session) -> Iterable[str]:
         (TaskFail, "2.3", missing_ti_config),
         (XCom, "2.3", missing_ti_config),
     ]
-    metadata = reflect_tables([*[x[0] for x in models_list], DagRun, TaskInstance], session)
+    metadata = reflect_tables([*(x[0] for x in models_list), DagRun, TaskInstance], session)
 
     if (
         not metadata.tables

Reply via email to