This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1103692e4bc Fix MyPy type errors in airflow_core models in dag.py,xcom.py for Sqlalchemy 2 migration (#57323)
1103692e4bc is described below

commit 1103692e4bc05ebef4724c97d49190ce87ddf950
Author: Anusha Kovi <[email protected]>
AuthorDate: Mon Oct 27 18:55:28 2025 +0530

    Fix MyPy type errors in airflow_core models in dag.py,xcom.py for Sqlalchemy 2 migration (#57323)
---
 airflow-core/src/airflow/models/dag.py  | 27 +++++++++++++++------------
 airflow-core/src/airflow/models/xcom.py | 10 +++++-----
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py
index 41123e80caa..2caf4888227 100644
--- a/airflow-core/src/airflow/models/dag.py
+++ b/airflow-core/src/airflow/models/dag.py
@@ -41,7 +41,7 @@ from sqlalchemy import (
 )
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import Mapped, backref, load_only, relationship
+from sqlalchemy.orm import Mapped, Session, backref, load_only, relationship
 from sqlalchemy.sql import expression
 
 from airflow import settings
@@ -69,9 +69,6 @@ from airflow.utils.types import DagRunType
 if TYPE_CHECKING:
     from typing import TypeAlias
 
-    from sqlalchemy.orm.query import Query
-    from sqlalchemy.orm.session import Session
-
     from airflow.models.mappedoperator import MappedOperator
     from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
 
@@ -206,7 +203,7 @@ def _get_model_data_interval(
     return DataInterval(start, end)
 
 
-def get_last_dagrun(dag_id, session, include_manually_triggered=False):
+def get_last_dagrun(dag_id: str, session: Session, include_manually_triggered: bool = False) -> DagRun | None:
     """
     Return the last dag run for a dag, None if there was none.
 
@@ -300,7 +297,7 @@ class DagOwnerAttributes(Base):
         return f"<DagOwnerAttributes: dag_id={self.dag_id}, owner={self.owner}, link={self.link}>"
 
     @classmethod
-    def get_all(cls, session) -> dict[str, dict[str, str]]:
+    def get_all(cls, session: Session) -> dict[str, dict[str, str]]:
         dag_links: dict = defaultdict(dict)
         for obj in session.scalars(select(cls)):
             dag_links[obj.dag_id].update({obj.owner: obj.link})
@@ -494,11 +491,13 @@ class DagModel(Base):
 
     @classmethod
     @provide_session
-    def get_current(cls, dag_id: str, session=NEW_SESSION) -> DagModel:
+    def get_current(cls, dag_id: str, session: Session = NEW_SESSION) -> DagModel | None:
         return session.scalar(select(cls).where(cls.dag_id == dag_id))
 
     @provide_session
-    def get_last_dagrun(self, session=NEW_SESSION, include_manually_triggered=False):
+    def get_last_dagrun(
+        self, session: Session = NEW_SESSION, include_manually_triggered: bool = False
+    ) -> DagRun | None:
         return get_last_dagrun(
             self.dag_id, session=session, include_manually_triggered=include_manually_triggered
         )
@@ -585,7 +584,7 @@ class DagModel(Base):
         return any_deactivated
 
     @classmethod
-    def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, datetime]]:
+    def dags_needing_dagruns(cls, session: Session) -> tuple[Any, dict[str, datetime]]:
         """
         Return (and lock) a list of Dag objects that are due to create a new DagRun.
 
@@ -706,7 +705,9 @@ class DagModel(Base):
         )
 
     @provide_session
-    def get_asset_triggered_next_run_info(self, *, session=NEW_SESSION) -> dict[str, int | str] | None:
+    def get_asset_triggered_next_run_info(
+        self, *, session: Session = NEW_SESSION
+    ) -> dict[str, int | str] | None:
         if self.asset_expression is None:
             return None
 
@@ -716,7 +717,7 @@ class DagModel(Base):
 
     @staticmethod
     @provide_session
-    def get_team_name(dag_id: str, session=NEW_SESSION) -> str | None:
+    def get_team_name(dag_id: str, session: Session = NEW_SESSION) -> str | None:
         """Return the team name associated to a Dag or None if it is not owned by a specific team."""
         stmt = (
             select(Team.name)
@@ -728,7 +729,9 @@ class DagModel(Base):
 
     @staticmethod
     @provide_session
-    def get_dag_id_to_team_name_mapping(dag_ids: list[str], session=NEW_SESSION) -> dict[str, str | None]:
+    def get_dag_id_to_team_name_mapping(
+        dag_ids: list[str], session: Session = NEW_SESSION
+    ) -> dict[str, str | None]:
         stmt = (
             select(DagModel.dag_id, Team.name)
             .join(DagBundleModel.teams)
diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py
index 197e10ebb5e..b17c7a72448 100644
--- a/airflow-core/src/airflow/models/xcom.py
+++ b/airflow-core/src/airflow/models/xcom.py
@@ -21,7 +21,7 @@ import json
 import logging
 from collections.abc import Iterable
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any
 
 from sqlalchemy import (
     JSON,
@@ -235,7 +235,7 @@ class XComModel(TaskInstanceDependencies):
             )
         )
 
-        new = cast("Any", cls)(  # Work around Mypy complaining model not defining '__init__'.
+        new = cls(
             dag_run_id=dag_run_id,
             key=key,
             value=value,
@@ -258,7 +258,7 @@ class XComModel(TaskInstanceDependencies):
         map_indexes: int | Iterable[int] | None = None,
         include_prior_dates: bool = False,
         limit: int | None = None,
-    ) -> Select:
+    ) -> Select[tuple[XComModel]]:
         """
         Composes a query to get one or more XCom entries.
 
@@ -348,7 +348,7 @@ class XComModel(TaskInstanceDependencies):
             raise ValueError("XCom value must be JSON serializable")
 
     @staticmethod
-    def deserialize_value(result) -> Any:
+    def deserialize_value(result: Any) -> Any:
         """
         Deserialize XCom value from a database result.
 
@@ -397,7 +397,7 @@ class LazyXComSelectSequence(LazySelectSequence[Any]):
     """
 
     @staticmethod
-    def _rebuild_select(stmt: TextClause) -> Select:
+    def _rebuild_select(stmt: TextClause) -> Select[tuple[Any]]:
         return select(XComModel.value).from_statement(stmt)
 
     @staticmethod

Reply via email to