This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3f6ac2f216 SQL query improvements in utils/db.py (#32518)
3f6ac2f216 is described below
commit 3f6ac2f216808cb3f227bb77ba751cb17fbd2a14
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Jul 11 18:58:31 2023 +0800
SQL query improvements in utils/db.py (#32518)
---
airflow/utils/db.py | 38 +++++++++++++++++++-------------------
1 file changed, 19 insertions(+), 19 deletions(-)
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 46ddcbfb34..f846b3f3a6 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -30,7 +30,6 @@ from tempfile import gettempdir
from typing import TYPE_CHECKING, Callable, Generator, Iterable
from sqlalchemy import Table, and_, column, delete, exc, func, inspect, or_, select, table, text, tuple_
-from sqlalchemy.orm.session import Session
import airflow
from airflow import settings
@@ -45,9 +44,10 @@ from airflow.utils.session import NEW_SESSION, create_session, provide_session
if TYPE_CHECKING:
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
- from sqlalchemy.orm import Query
+ from sqlalchemy.orm import Query, Session
from airflow.models.base import Base
+ from airflow.models.connection import Connection
log = logging.getLogger(__name__)
@@ -90,9 +90,9 @@ def _format_airflow_moved_table_name(source_table, version, category):
@provide_session
-def merge_conn(conn, session: Session = NEW_SESSION):
+def merge_conn(conn: Connection, session: Session = NEW_SESSION):
"""Add new Connection."""
- if not session.scalar(select(conn.__class__).filter_by(conn_id=conn.conn_id).limit(1)):
+ if not session.scalar(select(1).where(conn.__class__.conn_id == conn.conn_id)):
session.add(conn)
session.commit()
@@ -957,20 +957,20 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]:
"""
from airflow.models.connection import Connection
- dups = []
try:
- dups = session.execute(
+ dups = session.scalars(
select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1)
).all()
except (exc.OperationalError, exc.ProgrammingError):
# fallback if tables hasn't been created yet
session.rollback()
+ return
if dups:
yield (
"Seems you have non unique conn_id in connection table.\n"
"You have to manage those duplicate connections "
"before upgrading the database.\n"
- f"Duplicated conn_id: {[dup.conn_id for dup in dups]}"
+ f"Duplicated conn_id: {dups}"
)
@@ -1057,11 +1057,11 @@ def check_table_for_duplicates(
:param uniqueness: uniqueness constraint to evaluate against
:param session: session of the sqlalchemy
"""
- minimal_table_obj = table(table_name, *[column(x) for x in uniqueness])
+ minimal_table_obj = table(table_name, *(column(x) for x in uniqueness))
try:
subquery = session.execute(
select(minimal_table_obj, func.count().label("dupe_count"))
- .group_by(*[text(x) for x in uniqueness])
+ .group_by(*(text(x) for x in uniqueness))
.having(func.count() > text("1"))
.subquery()
)
@@ -1100,12 +1100,12 @@ def check_conn_type_null(session: Session) -> Iterable[str]:
"""
from airflow.models.connection import Connection
- n_nulls = []
try:
n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all()
except (exc.OperationalError, exc.ProgrammingError, exc.InternalError):
# fallback if tables hasn't been created yet
session.rollback()
+ return
if n_nulls:
yield (
@@ -1113,7 +1113,7 @@ def check_conn_type_null(session: Session) -> Iterable[str]:
"table must contain content.\n"
"Make sure you don't have null "
"in the conn_type column.\n"
- f"Null conn_type conn_id: {list(n_nulls)}"
+ f"Null conn_type conn_id: {n_nulls}"
)
@@ -1265,7 +1265,7 @@ def _dangling_against_dag_run(session, source_table, dag_run):
)
return (
- select(*[c.label(c.name) for c in source_table.c])
+ select(*(c.label(c.name) for c in source_table.c))
.join(dag_run, source_to_dag_run_join_cond, isouter=True)
.where(dag_run.c.dag_id.is_(None))
)
@@ -1306,9 +1306,9 @@ def _dangling_against_task_instance(session, source_table, dag_run, task_instanc
)
return (
- select(*[c.label(c.name) for c in source_table.c])
- .join(dag_run, dr_join_cond, isouter=True)
- .join(task_instance, ti_join_cond, isouter=True)
+ select(*(c.label(c.name) for c in source_table.c))
+ .outerjoin(dag_run, dr_join_cond)
+ .outerjoin(task_instance, ti_join_cond)
.where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None)))
)
@@ -1335,9 +1335,9 @@ def _move_duplicate_data_to_new_table(
dialect_name = bind.dialect.name
query = (
- select(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns])
+ select(*(source_table.c[x.name].label(str(x.name)) for x in source_table.columns))
.select_from(source_table)
- .join(subquery, and_(*[getattr(source_table.c, x) == getattr(subquery.c, x) for x in uniqueness]))
+ .join(subquery, and_(*(source_table.c[x] == subquery.c[x] for x in uniqueness)))
)
_create_table_as(
@@ -1353,7 +1353,7 @@ def _move_duplicate_data_to_new_table(
metadata = reflect_tables([target_table_name], session)
target_table = metadata.tables[target_table_name]
- where_clause = and_(*[getattr(source_table.c, x) == getattr(target_table.c, x) for x in uniqueness])
+ where_clause = and_(*(source_table.c[x] == target_table.c[x] for x in uniqueness))
if dialect_name == "sqlite":
subq = query.selectable.with_only_columns([text(f"{source_table}.ROWID")])
@@ -1410,7 +1410,7 @@ def check_bad_references(session: Session) -> Iterable[str]:
(TaskFail, "2.3", missing_ti_config),
(XCom, "2.3", missing_ti_config),
]
- metadata = reflect_tables([*[x[0] for x in models_list], DagRun, TaskInstance], session)
+ metadata = reflect_tables([*(x[0] for x in models_list), DagRun, TaskInstance], session)
if (
not metadata.tables