This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1103692e4bc Fix MyPy type errors in airflow_core models in
dag.py,xcom.py for Sqlalchemy 2 migration (#57323)
1103692e4bc is described below
commit 1103692e4bc05ebef4724c97d49190ce87ddf950
Author: Anusha Kovi <[email protected]>
AuthorDate: Mon Oct 27 18:55:28 2025 +0530
Fix MyPy type errors in airflow_core models in dag.py,xcom.py for
Sqlalchemy 2 migration (#57323)
---
airflow-core/src/airflow/models/dag.py | 27 +++++++++++++++------------
airflow-core/src/airflow/models/xcom.py | 10 +++++-----
2 files changed, 20 insertions(+), 17 deletions(-)
diff --git a/airflow-core/src/airflow/models/dag.py
b/airflow-core/src/airflow/models/dag.py
index 41123e80caa..2caf4888227 100644
--- a/airflow-core/src/airflow/models/dag.py
+++ b/airflow-core/src/airflow/models/dag.py
@@ -41,7 +41,7 @@ from sqlalchemy import (
)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import Mapped, backref, load_only, relationship
+from sqlalchemy.orm import Mapped, Session, backref, load_only, relationship
from sqlalchemy.sql import expression
from airflow import settings
@@ -69,9 +69,6 @@ from airflow.utils.types import DagRunType
if TYPE_CHECKING:
from typing import TypeAlias
- from sqlalchemy.orm.query import Query
- from sqlalchemy.orm.session import Session
-
from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import
SerializedBaseOperator, SerializedDAG
@@ -206,7 +203,7 @@ def _get_model_data_interval(
return DataInterval(start, end)
-def get_last_dagrun(dag_id, session, include_manually_triggered=False):
+def get_last_dagrun(dag_id: str, session: Session, include_manually_triggered:
bool = False) -> DagRun | None:
"""
Return the last dag run for a dag, None if there was none.
@@ -300,7 +297,7 @@ class DagOwnerAttributes(Base):
return f"<DagOwnerAttributes: dag_id={self.dag_id},
owner={self.owner}, link={self.link}>"
@classmethod
- def get_all(cls, session) -> dict[str, dict[str, str]]:
+ def get_all(cls, session: Session) -> dict[str, dict[str, str]]:
dag_links: dict = defaultdict(dict)
for obj in session.scalars(select(cls)):
dag_links[obj.dag_id].update({obj.owner: obj.link})
@@ -494,11 +491,13 @@ class DagModel(Base):
@classmethod
@provide_session
- def get_current(cls, dag_id: str, session=NEW_SESSION) -> DagModel:
+ def get_current(cls, dag_id: str, session: Session = NEW_SESSION) ->
DagModel | None:
return session.scalar(select(cls).where(cls.dag_id == dag_id))
@provide_session
- def get_last_dagrun(self, session=NEW_SESSION,
include_manually_triggered=False):
+ def get_last_dagrun(
+ self, session: Session = NEW_SESSION, include_manually_triggered: bool
= False
+ ) -> DagRun | None:
return get_last_dagrun(
self.dag_id, session=session,
include_manually_triggered=include_manually_triggered
)
@@ -585,7 +584,7 @@ class DagModel(Base):
return any_deactivated
@classmethod
- def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str,
datetime]]:
+ def dags_needing_dagruns(cls, session: Session) -> tuple[Any, dict[str,
datetime]]:
"""
Return (and lock) a list of Dag objects that are due to create a new
DagRun.
@@ -706,7 +705,9 @@ class DagModel(Base):
)
@provide_session
- def get_asset_triggered_next_run_info(self, *, session=NEW_SESSION) ->
dict[str, int | str] | None:
+ def get_asset_triggered_next_run_info(
+ self, *, session: Session = NEW_SESSION
+ ) -> dict[str, int | str] | None:
if self.asset_expression is None:
return None
@@ -716,7 +717,7 @@ class DagModel(Base):
@staticmethod
@provide_session
- def get_team_name(dag_id: str, session=NEW_SESSION) -> str | None:
+ def get_team_name(dag_id: str, session: Session = NEW_SESSION) -> str |
None:
"""Return the team name associated to a Dag or None if it is not owned
by a specific team."""
stmt = (
select(Team.name)
@@ -728,7 +729,9 @@ class DagModel(Base):
@staticmethod
@provide_session
- def get_dag_id_to_team_name_mapping(dag_ids: list[str],
session=NEW_SESSION) -> dict[str, str | None]:
+ def get_dag_id_to_team_name_mapping(
+ dag_ids: list[str], session: Session = NEW_SESSION
+ ) -> dict[str, str | None]:
stmt = (
select(DagModel.dag_id, Team.name)
.join(DagBundleModel.teams)
diff --git a/airflow-core/src/airflow/models/xcom.py
b/airflow-core/src/airflow/models/xcom.py
index 197e10ebb5e..b17c7a72448 100644
--- a/airflow-core/src/airflow/models/xcom.py
+++ b/airflow-core/src/airflow/models/xcom.py
@@ -21,7 +21,7 @@ import json
import logging
from collections.abc import Iterable
from datetime import datetime
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any
from sqlalchemy import (
JSON,
@@ -235,7 +235,7 @@ class XComModel(TaskInstanceDependencies):
)
)
- new = cast("Any", cls)( # Work around Mypy complaining model not
defining '__init__'.
+ new = cls(
dag_run_id=dag_run_id,
key=key,
value=value,
@@ -258,7 +258,7 @@ class XComModel(TaskInstanceDependencies):
map_indexes: int | Iterable[int] | None = None,
include_prior_dates: bool = False,
limit: int | None = None,
- ) -> Select:
+ ) -> Select[tuple[XComModel]]:
"""
Composes a query to get one or more XCom entries.
@@ -348,7 +348,7 @@ class XComModel(TaskInstanceDependencies):
raise ValueError("XCom value must be JSON serializable")
@staticmethod
- def deserialize_value(result) -> Any:
+ def deserialize_value(result: Any) -> Any:
"""
Deserialize XCom value from a database result.
@@ -397,7 +397,7 @@ class LazyXComSelectSequence(LazySelectSequence[Any]):
"""
@staticmethod
- def _rebuild_select(stmt: TextClause) -> Select:
+ def _rebuild_select(stmt: TextClause) -> Select[tuple[Any]]:
return select(XComModel.value).from_statement(stmt)
@staticmethod