This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2efee57594f Fix remaining MyPy type errors in utils/ (#56982)
2efee57594f is described below

commit 2efee57594f25f708fe2c5bf6a3a2495767a25a9
Author: LI,JHE-CHEN <[email protected]>
AuthorDate: Mon Oct 27 09:17:00 2025 -0400

    Fix remaining MyPy type errors in utils/ (#56982)
---
 .../src/airflow/utils/cli_action_loggers.py        |  2 +-
 airflow-core/src/airflow/utils/context.py          |  1 -
 airflow-core/src/airflow/utils/db.py               | 46 +++++++++++++++-------
 airflow-core/src/airflow/utils/db_cleanup.py       | 10 ++---
 airflow-core/src/airflow/utils/session.py          |  2 +-
 airflow-core/src/airflow/utils/sqlalchemy.py       | 20 +++++++---
 airflow-core/tests/unit/utils/test_sqlalchemy.py   |  1 +
 7 files changed, 55 insertions(+), 27 deletions(-)

diff --git a/airflow-core/src/airflow/utils/cli_action_loggers.py b/airflow-core/src/airflow/utils/cli_action_loggers.py
index dd3d8d59fa8..6ec28079af5 100644
--- a/airflow-core/src/airflow/utils/cli_action_loggers.py
+++ b/airflow-core/src/airflow/utils/cli_action_loggers.py
@@ -130,7 +130,7 @@ def default_action_log(
         # Use bulk_insert_mappings here to avoid importing all models (which using the classes does) early
         # on in the CLI
         session.bulk_insert_mappings(
-            Log,
+            Log,  # type: ignore[arg-type]
             [
                 {
                     "event": f"cli_{sub_command}",
diff --git a/airflow-core/src/airflow/utils/context.py b/airflow-core/src/airflow/utils/context.py
index c9013b152c4..abf6bb1a530 100644
--- a/airflow-core/src/airflow/utils/context.py
+++ b/airflow-core/src/airflow/utils/context.py
@@ -147,7 +147,6 @@ class OutletEventAccessors(OutletEventAccessorsSDK):
 
         if asset is None:
             raise ValueError("No active asset found with either name or uri.")
-
         return asset.to_public()
 
 
diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py
index 36308d7b148..e65632f89ee 100644
--- a/airflow-core/src/airflow/utils/db.py
+++ b/airflow-core/src/airflow/utils/db.py
@@ -59,6 +59,7 @@ from airflow.models import import_all_models
 from airflow.utils import helpers
 from airflow.utils.db_manager import RunDBManager
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import get_dialect_name
 from airflow.utils.task_instance_session import get_current_task_instance_session
 
 USE_PSYCOPG3: bool
@@ -82,7 +83,7 @@ if TYPE_CHECKING:
     from sqlalchemy.engine import Row
     from sqlalchemy.ext.asyncio import AsyncSession
     from sqlalchemy.orm import Session
-    from sqlalchemy.sql.elements import ClauseElement, TextClause
+    from sqlalchemy.sql.elements import ColumnElement, TextClause
     from sqlalchemy.sql.selectable import Select
 
     from airflow.models.connection import Connection
@@ -1346,10 +1347,14 @@ def create_global_lock(
     lock_timeout: int = 1800,
 ) -> Generator[None, None, None]:
     """Contextmanager that will create and teardown a global db lock."""
-    conn = session.get_bind().connect()
-    dialect = conn.dialect
+    bind = session.get_bind()
+    if hasattr(bind, "connect"):
+        conn = bind.connect()
+    else:
+        conn = bind
+    dialect_name = get_dialect_name(session)
     try:
-        if dialect.name == "postgresql":
+        if dialect_name == "postgresql":
             if USE_PSYCOPG3:
                 # psycopg3 doesn't support parameters for `SET`. Use `set_config` instead.
                 # The timeout value must be passed as a string of milliseconds.
@@ -1361,21 +1366,32 @@ def create_global_lock(
             else:
                 conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
                 conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
-        elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
+        elif (
+            dialect_name == "mysql"
+            and conn.dialect.server_version_info
+            and conn.dialect.server_version_info >= (5, 6)
+        ):
             conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout})
 
         yield
     finally:
-        if dialect.name == "postgresql":
+        if dialect_name == "postgresql":
             if USE_PSYCOPG3:
                 # Use set_config() to reset the timeout to its default (0 = off/wait forever).
                 conn.execute(text("SELECT set_config('lock_timeout', '0', false)"))
             else:
                 conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
-            (unlocked,) = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone()
+            result = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone()
+            if result is None:
+                raise RuntimeError("Error releasing DB lock!")
+            (unlocked,) = result
             if not unlocked:
                 raise RuntimeError("Error releasing DB lock!")
-        elif dialect.name == "mysql" and dialect.server_version_info >= (5, 6):
+        elif (
+            dialect_name == "mysql"
+            and conn.dialect.server_version_info
+            and conn.dialect.server_version_info >= (5, 6)
+        ):
             conn.execute(text("select RELEASE_LOCK(:id)"), {"id": str(lock)})
 
 
@@ -1460,7 +1476,8 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int:
     :meta private:
     """
     count_stmt = select(func.count()).select_from(query_stmt.order_by(None).subquery())
-    return session.scalar(count_stmt)
+    result = session.scalar(count_stmt)
+    return result or 0
 
 
 async def get_query_count_async(statement: Select, *, session: AsyncSession) -> int:
@@ -1475,7 +1492,8 @@ async def get_query_count_async(statement: Select, *, session: AsyncSession) ->
     :meta private:
     """
     count_stmt = select(func.count()).select_from(statement.order_by(None).subquery())
-    return await session.scalar(count_stmt)
+    result = await session.scalar(count_stmt)
+    return result or 0
 
 
 def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
@@ -1493,7 +1511,7 @@ def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
     return bool(session.scalar(count_stmt))
 
 
-def exists_query(*where: ClauseElement, session: Session) -> bool:
+def exists_query(*where: ColumnElement[bool], session: Session) -> bool:
     """
     Check whether there is at least one row matching given clauses.
 
@@ -1527,8 +1545,8 @@ class LazySelectSequence(Sequence[T]):
     :meta private:
     """
 
-    _select_asc: ClauseElement
-    _select_desc: ClauseElement
+    _select_asc: Select
+    _select_desc: Select
     _session: Session = attrs.field(kw_only=True, factory=get_current_task_instance_session)
     _len: int | None = attrs.field(init=False, default=None)
 
@@ -1537,7 +1555,7 @@ class LazySelectSequence(Sequence[T]):
         cls,
         select: Select,
         *,
-        order_by: Sequence[ClauseElement],
+        order_by: Sequence[ColumnElement],
         session: Session | None = None,
     ) -> Self:
         s1 = select
diff --git a/airflow-core/src/airflow/utils/db_cleanup.py b/airflow-core/src/airflow/utils/db_cleanup.py
index 5dae8dbe847..18b7413f64d 100644
--- a/airflow-core/src/airflow/utils/db_cleanup.py
+++ b/airflow-core/src/airflow/utils/db_cleanup.py
@@ -251,7 +251,7 @@ def _do_delete(
                 )
             else:
                 delete = source_table.delete().where(
-                    and_(col == target_table.c[col.name] for col in source_table.primary_key.columns)
+                    and_(*[col == target_table.c[col.name] for col in source_table.primary_key.columns])
                 )
             logger.debug("delete statement:\n%s", delete.compile())
             session.execute(delete)
@@ -270,7 +270,7 @@ def _do_delete(
 
 def _subquery_keep_last(
     *, recency_column, keep_last_filters, group_by_columns, max_date_colname, session: Session
-) -> Query:
+):
     subquery = select(*group_by_columns, func.max(recency_column).label(max_date_colname))
 
     if keep_last_filters is not None:
@@ -316,7 +316,7 @@ def _build_query(
     conditions = [base_table_recency_col < clean_before_timestamp]
     if keep_last:
         max_date_col_name = "max_date_per_group"
-        group_by_columns = [column(x) for x in keep_last_group_by]
+        group_by_columns: list[Any] = [column(x) for x in keep_last_group_by]
         subquery = _subquery_keep_last(
             recency_column=recency_column,
             keep_last_filters=keep_last_filters,
@@ -327,7 +327,7 @@ def _build_query(
         query = query.select_from(base_table).outerjoin(
             subquery,
             and_(
-                *[base_table.c[x] == subquery.c[x] for x in keep_last_group_by],
+                *[base_table.c[x] == subquery.c[x] for x in keep_last_group_by],  # type: ignore[attr-defined]
                 base_table_recency_col == column(max_date_col_name),
             ),
         )
@@ -475,7 +475,7 @@ def _get_archived_table_names(table_names: list[str] | None, session: Session) -
     inspector = inspect(session.bind)
     db_table_names = [
         x
-        for x in inspector.get_table_names()
+        for x in (inspector.get_table_names() if inspector else [])
         if x.startswith(ARCHIVE_TABLE_PREFIX) or x in ARCHIVED_TABLES_FROM_DB_MIGRATIONS
     ]
     effective_table_names, _ = _effective_table_names(table_names=table_names)
diff --git a/airflow-core/src/airflow/utils/session.py b/airflow-core/src/airflow/utils/session.py
index 211c1645320..75509c14e01 100644
--- a/airflow-core/src/airflow/utils/session.py
+++ b/airflow-core/src/airflow/utils/session.py
@@ -97,7 +97,7 @@ def provide_session(func: Callable[PS, RT]) -> Callable[PS, RT]:
         if "session" in kwargs or session_args_idx < len(args):
             return func(*args, **kwargs)
         with create_session() as session:
-            return func(*args, session=session, **kwargs)
+            return func(*args, session=session, **kwargs)  # type: ignore[arg-type]
 
     return wrapper
 
diff --git a/airflow-core/src/airflow/utils/sqlalchemy.py b/airflow-core/src/airflow/utils/sqlalchemy.py
index 01dcf0b0848..0b0a0c7533c 100644
--- a/airflow-core/src/airflow/utils/sqlalchemy.py
+++ b/airflow-core/src/airflow/utils/sqlalchemy.py
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
     from sqlalchemy.exc import OperationalError
     from sqlalchemy.orm import Query, Session
     from sqlalchemy.sql import Select
+    from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.types import TypeEngine
 
     from airflow.typing_compat import Self
@@ -53,7 +54,7 @@ try:
     from sqlalchemy.orm import mapped_column
 except ImportError:
     # fallback for SQLAlchemy < 2.0
-    def mapped_column(*args, **kwargs):
+    def mapped_column(*args, **kwargs):  # type: ignore[misc]
         from sqlalchemy import Column
 
         return Column(*args, **kwargs)
@@ -312,7 +313,7 @@ class ExecutorConfigType(PickleType):
             return False
 
 
-def nulls_first(col, session: Session) -> dict[str, Any]:
+def nulls_first(col: ColumnElement, session: Session) -> ColumnElement:
     """
     Specify *NULLS FIRST* to the column ordering.
 
@@ -356,12 +357,19 @@ def with_row_locks(
     :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
     :return: updated query
     """
-    dialect = session.bind.dialect
+    try:
+        dialect_name = get_dialect_name(session)
+    except ValueError:
+        return query
+    if not dialect_name:
+        return query
 
     # Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8) does not support it.
     if not USE_ROW_LEVEL_LOCKING:
         return query
-    if dialect.name == "mysql" and not dialect.supports_for_update_of:
+    if dialect_name == "mysql" and not getattr(
+        session.bind.dialect if session.bind else None, "supports_for_update_of", False
+    ):
         return query
     if nowait:
         kwargs["nowait"] = True
@@ -448,7 +456,9 @@ def is_lock_not_available_error(error: OperationalError):
     #               is set.'
     # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction
     #              (when NOWAIT isn't available)
-    db_err_code = getattr(error.orig, "pgcode", None) or error.orig.args[0]
+    db_err_code = getattr(error.orig, "pgcode", None) or (
+        error.orig.args[0] if error.orig and error.orig.args else None
+    )
 
     # We could test if error.orig is an instance of
     # psycopg2.errors.LockNotAvailable/_mysql_exceptions.OperationalError, but that involves
diff --git a/airflow-core/tests/unit/utils/test_sqlalchemy.py b/airflow-core/tests/unit/utils/test_sqlalchemy.py
index c0a4761a492..8d870a92a95 100644
--- a/airflow-core/tests/unit/utils/test_sqlalchemy.py
+++ b/airflow-core/tests/unit/utils/test_sqlalchemy.py
@@ -169,6 +169,7 @@ class TestSqlAlchemyUtils:
         session = mock.Mock()
         session.bind.dialect.name = dialect
         session.bind.dialect.supports_for_update_of = supports_for_update_of
+        session.get_bind.return_value = session.bind
         with mock.patch("airflow.utils.sqlalchemy.USE_ROW_LEVEL_LOCKING", use_row_level_lock_conf):
             returned_value = with_row_locks(query=query, session=session, nowait=True)
 

Reply via email to