This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2efee57594f Fix remaining MyPy type errors in utils/ (#56982)
2efee57594f is described below
commit 2efee57594f25f708fe2c5bf6a3a2495767a25a9
Author: LI,JHE-CHEN <[email protected]>
AuthorDate: Mon Oct 27 09:17:00 2025 -0400
Fix remaining MyPy type errors in utils/ (#56982)
---
.../src/airflow/utils/cli_action_loggers.py | 2 +-
airflow-core/src/airflow/utils/context.py | 1 -
airflow-core/src/airflow/utils/db.py | 46 +++++++++++++++-------
airflow-core/src/airflow/utils/db_cleanup.py | 10 ++---
airflow-core/src/airflow/utils/session.py | 2 +-
airflow-core/src/airflow/utils/sqlalchemy.py | 20 +++++++---
airflow-core/tests/unit/utils/test_sqlalchemy.py | 1 +
7 files changed, 55 insertions(+), 27 deletions(-)
diff --git a/airflow-core/src/airflow/utils/cli_action_loggers.py b/airflow-core/src/airflow/utils/cli_action_loggers.py
index dd3d8d59fa8..6ec28079af5 100644
--- a/airflow-core/src/airflow/utils/cli_action_loggers.py
+++ b/airflow-core/src/airflow/utils/cli_action_loggers.py
@@ -130,7 +130,7 @@ def default_action_log(
# Use bulk_insert_mappings here to avoid importing all models (which using the classes does) early
# on in the CLI
session.bulk_insert_mappings(
- Log,
+ Log, # type: ignore[arg-type]
[
{
"event": f"cli_{sub_command}",
diff --git a/airflow-core/src/airflow/utils/context.py b/airflow-core/src/airflow/utils/context.py
index c9013b152c4..abf6bb1a530 100644
--- a/airflow-core/src/airflow/utils/context.py
+++ b/airflow-core/src/airflow/utils/context.py
@@ -147,7 +147,6 @@ class OutletEventAccessors(OutletEventAccessorsSDK):
if asset is None:
raise ValueError("No active asset found with either name or uri.")
-
return asset.to_public()
diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py
index 36308d7b148..e65632f89ee 100644
--- a/airflow-core/src/airflow/utils/db.py
+++ b/airflow-core/src/airflow/utils/db.py
@@ -59,6 +59,7 @@ from airflow.models import import_all_models
from airflow.utils import helpers
from airflow.utils.db_manager import RunDBManager
from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import get_dialect_name
from airflow.utils.task_instance_session import get_current_task_instance_session
USE_PSYCOPG3: bool
@@ -82,7 +83,7 @@ if TYPE_CHECKING:
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
- from sqlalchemy.sql.elements import ClauseElement, TextClause
+ from sqlalchemy.sql.elements import ColumnElement, TextClause
from sqlalchemy.sql.selectable import Select
from airflow.models.connection import Connection
@@ -1346,10 +1347,14 @@ def create_global_lock(
lock_timeout: int = 1800,
) -> Generator[None, None, None]:
"""Contextmanager that will create and teardown a global db lock."""
- conn = session.get_bind().connect()
- dialect = conn.dialect
+ bind = session.get_bind()
+ if hasattr(bind, "connect"):
+ conn = bind.connect()
+ else:
+ conn = bind
+ dialect_name = get_dialect_name(session)
try:
- if dialect.name == "postgresql":
+ if dialect_name == "postgresql":
if USE_PSYCOPG3:
# psycopg3 doesn't support parameters for `SET`. Use `set_config` instead.
# The timeout value must be passed as a string of milliseconds.
@@ -1361,21 +1366,32 @@ def create_global_lock(
else:
conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout":
lock_timeout})
conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id":
lock.value})
- elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
+ elif (
+ dialect_name == "mysql"
+ and conn.dialect.server_version_info
+ and conn.dialect.server_version_info >= (5, 6)
+ ):
conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id":
str(lock), "timeout": lock_timeout})
yield
finally:
- if dialect.name == "postgresql":
+ if dialect_name == "postgresql":
if USE_PSYCOPG3:
# Use set_config() to reset the timeout to its default (0 = off/wait forever).
conn.execute(text("SELECT set_config('lock_timeout', '0',
false)"))
else:
conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
- (unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"),
{"id": lock.value}).fetchone()
+ result = conn.execute(text("SELECT pg_advisory_unlock(:id)"),
{"id": lock.value}).fetchone()
+ if result is None:
+ raise RuntimeError("Error releasing DB lock!")
+ (unlocked,) = result
if not unlocked:
raise RuntimeError("Error releasing DB lock!")
- elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
+ elif (
+ dialect_name == "mysql"
+ and conn.dialect.server_version_info
+ and conn.dialect.server_version_info >= (5, 6)
+ ):
conn.execute(text("select RELEASE_LOCK(:id)"), {"id": str(lock)})
@@ -1460,7 +1476,8 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int:
:meta private:
"""
count_stmt = select(func.count()).select_from(query_stmt.order_by(None).subquery())
- return session.scalar(count_stmt)
+ result = session.scalar(count_stmt)
+ return result or 0
async def get_query_count_async(statement: Select, *, session: AsyncSession) -> int:
@@ -1475,7 +1492,8 @@ async def get_query_count_async(statement: Select, *, session: AsyncSession) ->
:meta private:
"""
count_stmt = select(func.count()).select_from(statement.order_by(None).subquery())
- return await session.scalar(count_stmt)
+ result = await session.scalar(count_stmt)
+ return result or 0
def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
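Both count helpers above now guard the return because Session.scalar() is typed as Optional (the driver could in principle yield no row), even though SELECT COUNT(*) always produces one in practice. The narrowing pattern in isolation:

from typing import Optional

def narrow_count(result: Optional[int]) -> int:
    # "or 0" maps both None and 0 to 0, so MyPy sees a plain int return.
    return result or 0

assert narrow_count(None) == 0
assert narrow_count(5) == 5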
@@ -1493,7 +1511,7 @@ def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
return bool(session.scalar(count_stmt))
-def exists_query(*where: ClauseElement, session: Session) -> bool:
+def exists_query(*where: ColumnElement[bool], session: Session) -> bool:
"""
Check whether there is at least one row matching given clauses.
@@ -1527,8 +1545,8 @@ class LazySelectSequence(Sequence[T]):
:meta private:
"""
- _select_asc: ClauseElement
- _select_desc: ClauseElement
+ _select_asc: Select
+ _select_desc: Select
_session: Session = attrs.field(kw_only=True, factory=get_current_task_instance_session)
_len: int | None = attrs.field(init=False, default=None)
@@ -1537,7 +1555,7 @@ class LazySelectSequence(Sequence[T]):
cls,
select: Select,
*,
- order_by: Sequence[ClauseElement],
+ order_by: Sequence[ColumnElement],
session: Session | None = None,
) -> Self:
s1 = select
diff --git a/airflow-core/src/airflow/utils/db_cleanup.py b/airflow-core/src/airflow/utils/db_cleanup.py
index 5dae8dbe847..18b7413f64d 100644
--- a/airflow-core/src/airflow/utils/db_cleanup.py
+++ b/airflow-core/src/airflow/utils/db_cleanup.py
@@ -251,7 +251,7 @@ def _do_delete(
)
else:
delete = source_table.delete().where(
- and_(col == target_table.c[col.name] for col in source_table.primary_key.columns)
+ and_(*[col == target_table.c[col.name] for col in source_table.primary_key.columns])
)
logger.debug("delete statement:\n%s", delete.compile())
session.execute(delete)
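The and_() fix above is behavioral as well as typing-related: and_() takes each condition as a separate positional argument, and the old code handed it a single generator object. A runnable illustration of the corrected call, using a hypothetical two-column table:

from sqlalchemy import Column, Integer, MetaData, Table, and_

metadata = MetaData()
t = Table("t", metadata, Column("a", Integer), Column("b", Integer))

# Build one condition per column, then unpack the list into and_().
conditions = [t.c[name] == 1 for name in ("a", "b")]
print(and_(*conditions))  # t.a = :a_1 AND t.b = :b_1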
@@ -270,7 +270,7 @@ def _do_delete(
def _subquery_keep_last(
*, recency_column, keep_last_filters, group_by_columns, max_date_colname, session: Session
-) -> Query:
+):
subquery = select(*group_by_columns, func.max(recency_column).label(max_date_colname))
if keep_last_filters is not None:
@@ -316,7 +316,7 @@ def _build_query(
conditions = [base_table_recency_col < clean_before_timestamp]
if keep_last:
max_date_col_name = "max_date_per_group"
- group_by_columns = [column(x) for x in keep_last_group_by]
+ group_by_columns: list[Any] = [column(x) for x in keep_last_group_by]
subquery = _subquery_keep_last(
recency_column=recency_column,
keep_last_filters=keep_last_filters,
@@ -327,7 +327,7 @@ def _build_query(
query = query.select_from(base_table).outerjoin(
subquery,
and_(
- *[base_table.c[x] == subquery.c[x] for x in keep_last_group_by],
+ *[base_table.c[x] == subquery.c[x] for x in keep_last_group_by],  # type: ignore[attr-defined]
base_table_recency_col == column(max_date_col_name),
),
)
@@ -475,7 +475,7 @@ def _get_archived_table_names(table_names: list[str] | None, session: Session) -
inspector = inspect(session.bind)
db_table_names = [
x
- for x in inspector.get_table_names()
+ for x in (inspector.get_table_names() if inspector else [])
if x.startswith(ARCHIVE_TABLE_PREFIX) or x in ARCHIVED_TABLES_FROM_DB_MIGRATIONS
]
effective_table_names, _ = _effective_table_names(table_names=table_names)
diff --git a/airflow-core/src/airflow/utils/session.py b/airflow-core/src/airflow/utils/session.py
index 211c1645320..75509c14e01 100644
--- a/airflow-core/src/airflow/utils/session.py
+++ b/airflow-core/src/airflow/utils/session.py
@@ -97,7 +97,7 @@ def provide_session(func: Callable[PS, RT]) -> Callable[PS, RT]:
if "session" in kwargs or session_args_idx < len(args):
return func(*args, **kwargs)
with create_session() as session:
- return func(*args, session=session, **kwargs)
+ return func(*args, session=session, **kwargs)  # type: ignore[arg-type]
return wrapper
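For context, provide_session injects a session only when the caller did not pass one, which is exactly the call site the scoped ignore now covers. A short usage sketch (hypothetical function name):

from sqlalchemy import func, select

from airflow.models.log import Log
from airflow.utils.session import NEW_SESSION, provide_session

@provide_session
def count_log_rows(session=NEW_SESSION) -> int:
    # Callers may omit `session`; the wrapper opens and closes one for them.
    return session.scalar(select(func.count()).select_from(Log)) or 0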
diff --git a/airflow-core/src/airflow/utils/sqlalchemy.py b/airflow-core/src/airflow/utils/sqlalchemy.py
index 01dcf0b0848..0b0a0c7533c 100644
--- a/airflow-core/src/airflow/utils/sqlalchemy.py
+++ b/airflow-core/src/airflow/utils/sqlalchemy.py
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Query, Session
from sqlalchemy.sql import Select
+ from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.types import TypeEngine
from airflow.typing_compat import Self
@@ -53,7 +54,7 @@ try:
from sqlalchemy.orm import mapped_column
except ImportError:
# fallback for SQLAlchemy < 2.0
- def mapped_column(*args, **kwargs):
+ def mapped_column(*args, **kwargs): # type: ignore[misc]
from sqlalchemy import Column
return Column(*args, **kwargs)
@@ -312,7 +313,7 @@ class ExecutorConfigType(PickleType):
return False
-def nulls_first(col, session: Session) -> dict[str, Any]:
+def nulls_first(col: ColumnElement, session: Session) -> ColumnElement:
"""
Specify *NULLS FIRST* to the column ordering.
@@ -356,12 +357,19 @@ def with_row_locks(
:param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
:return: updated query
"""
- dialect = session.bind.dialect
+ try:
+ dialect_name = get_dialect_name(session)
+ except ValueError:
+ return query
+ if not dialect_name:
+ return query
# Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8) does not support it.
if not USE_ROW_LEVEL_LOCKING:
return query
- if dialect.name == "mysql" and not dialect.supports_for_update_of:
+ if dialect_name == "mysql" and not getattr(
+ session.bind.dialect if session.bind else None, "supports_for_update_of", False
+ ):
return query
if nowait:
kwargs["nowait"] = True
@@ -448,7 +456,9 @@ def is_lock_not_available_error(error: OperationalError):
# is set.'
# MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction
# (when NOWAIT isn't available)
- db_err_code = getattr(error.orig, "pgcode", None) or error.orig.args[0]
+ db_err_code = getattr(error.orig, "pgcode", None) or (
+ error.orig.args[0] if error.orig and error.orig.args else None
+ )
# We could test if error.orig is an instance of
# psycopg2.errors.LockNotAvailable/_mysql_exceptions.OperationalError, but that involves
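The extra parenthesized guard matters because OperationalError.orig is Optional in SQLAlchemy's typing, so both .orig and .orig.args must be checked before indexing. The same narrowing in isolation:

from typing import Any, Optional

def extract_code(orig: Optional[BaseException]) -> Optional[Any]:
    # getattr covers psycopg's .pgcode; the guarded fallback covers MySQL's
    # numeric first argument without indexing a possibly-None .orig.
    return getattr(orig, "pgcode", None) or (orig.args[0] if orig and orig.args else None)

assert extract_code(None) is None
assert extract_code(Exception(1205)) == 1205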
diff --git a/airflow-core/tests/unit/utils/test_sqlalchemy.py b/airflow-core/tests/unit/utils/test_sqlalchemy.py
index c0a4761a492..8d870a92a95 100644
--- a/airflow-core/tests/unit/utils/test_sqlalchemy.py
+++ b/airflow-core/tests/unit/utils/test_sqlalchemy.py
@@ -169,6 +169,7 @@ class TestSqlAlchemyUtils:
session = mock.Mock()
session.bind.dialect.name = dialect
session.bind.dialect.supports_for_update_of = supports_for_update_of
+ session.get_bind.return_value = session.bind
with mock.patch("airflow.utils.sqlalchemy.USE_ROW_LEVEL_LOCKING",
use_row_level_lock_conf):
returned_value = with_row_locks(query=query, session=session, nowait=True)