This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5fef6984c61 Fix type hints (#61317)
5fef6984c61 is described below
commit 5fef6984c61fbf67d5567a4cdb4984662289212c
Author: Dev-iL <[email protected]>
AuthorDate: Wed Feb 4 17:09:24 2026 +0200
Fix type hints (#61317)
---
.../api_fastapi/auth/managers/base_auth_manager.py | 7 +++-
.../api_fastapi/core_api/routes/public/assets.py | 5 ++-
.../core_api/routes/public/dag_stats.py | 9 ++++-
.../api_fastapi/core_api/routes/public/dag_tags.py | 3 +-
.../api_fastapi/core_api/routes/public/xcom.py | 5 ++-
.../core_api/services/public/dag_run.py | 10 +++--
.../api_fastapi/execution_api/routes/dag_runs.py | 2 +-
.../api_fastapi/execution_api/routes/xcoms.py | 9 +++--
.../src/airflow/dag_processing/collection.py | 27 +++++++------
airflow-core/src/airflow/dag_processing/manager.py | 20 +++++-----
.../src/airflow/jobs/scheduler_job_runner.py | 5 +--
airflow-core/src/airflow/models/backfill.py | 3 +-
airflow-core/src/airflow/models/deadline.py | 5 ++-
airflow-core/src/airflow/models/pool.py | 3 +-
airflow-core/src/airflow/models/taskreschedule.py | 4 +-
.../airflow/serialization/definitions/deadline.py | 4 +-
.../src/airflow/serialization/definitions/node.py | 7 +++-
.../src/airflow/ti_deps/deps/trigger_rule_dep.py | 4 +-
airflow-core/src/airflow/typing_compat.py | 12 ++----
airflow-core/src/airflow/utils/sqlalchemy.py | 6 +--
airflow-core/tests/unit/models/test_dag.py | 3 +-
.../tests/unit/models/test_taskinstance.py | 4 +-
.../_internal_client/secret_manager_client.py | 4 +-
.../providers/google/common/hooks/base_google.py | 1 +
shared/dagnode/src/airflow_shared/dagnode/node.py | 44 ++++++++++++++++++++--
.../observability/metrics/datadog_logger.py | 4 +-
.../airflow_shared/observability/metrics/stats.py | 1 +
.../plugins_manager/plugins_manager.py | 7 +++-
.../providers_discovery/providers_discovery.py | 10 ++++-
.../src/airflow/sdk/definitions/_internal/node.py | 2 +-
30 files changed, 152 insertions(+), 78 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
index 5e98dc86fb0..8f9f1dd8e00 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py
@@ -47,6 +47,7 @@ from airflow.configuration import conf
from airflow.models import Connection, DagModel, Pool, Variable
from airflow.models.dagbundle import DagBundleModel
from airflow.models.team import Team, dag_bundle_team_association_table
+from airflow.typing_compat import Unpack
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
@@ -54,6 +55,7 @@ if TYPE_CHECKING:
from collections.abc import Sequence
from fastapi import FastAPI
+ from sqlalchemy import Row
from sqlalchemy.orm import Session
from airflow.api_fastapi.auth.managers.models.batch_apis import (
@@ -569,8 +571,9 @@ class BaseAuthManager(Generic[T], LoggingMixin, metaclass=ABCMeta):
isouter=True,
)
)
- rows = session.execute(stmt).all()
- dags_by_team: dict[str | None, set[str]] = defaultdict(set)
+ # The below type annotation is acceptable on SQLA2.1, but not on 2.0
+    rows: Sequence[Row[Unpack[tuple[str, str]]]] = session.execute(stmt).all()  # type: ignore[type-arg]
+ dags_by_team: dict[str, set[str]] = defaultdict(set)
for dag_id, team_name in rows:
dags_by_team[team_name].add(dag_id)
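
(Illustrative aside, not part of the commit: the variadic Row[Unpack[...]] spelling used in the hunk above can be exercised standalone. This sketch assumes SQLAlchemy 2.x and typing_extensions are installed; the SELECT literals and variable names are made up for the example, and typing_extensions.Unpack stands in for airflow.typing_compat.Unpack.)

    from __future__ import annotations

    from collections.abc import Sequence
    from typing import TYPE_CHECKING

    from sqlalchemy import create_engine, text
    from typing_extensions import Unpack  # stand-in for airflow.typing_compat.Unpack

    if TYPE_CHECKING:
        from sqlalchemy import Row

    engine = create_engine("sqlite://")  # throwaway in-memory database
    with engine.connect() as conn:
        stmt = text("SELECT 'example_dag' AS dag_id, 'team_a' AS team_name")
        # The variadic Row[...] form type-checks cleanly against SQLAlchemy 2.1 stubs;
        # 2.0 still needs the same type: ignore[type-arg] used in the hunk above.
        rows: Sequence[Row[Unpack[tuple[str, str]]]] = conn.execute(stmt).all()  # type: ignore[type-arg]
        for dag_id, team_name in rows:
            print(dag_id, team_name)
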
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py
index 0c23cc93a78..3e5108f1447 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py
@@ -74,10 +74,12 @@ from airflow.models.asset import (
AssetWatcherModel,
TaskOutletAssetReference,
)
+from airflow.typing_compat import Unpack
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
if TYPE_CHECKING:
+ from sqlalchemy.engine import Result
from sqlalchemy.sql import Select
assets_router = AirflowRouter(tags=["Asset"])
@@ -179,7 +181,8 @@ def get_assets(
session=session,
)
- assets_rows = session.execute(
+ # The below type annotation is acceptable on SQLA2.1, but not on 2.0
+    assets_rows: Result[Unpack[tuple[AssetModel, int, datetime]]] = session.execute(  # type: ignore[type-arg]
assets_select.options(
subqueryload(AssetModel.scheduled_dags),
subqueryload(AssetModel.producing_tasks),
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_stats.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_stats.py
index d2b4cc17bf1..b42607369d3 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_stats.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_stats.py
@@ -17,7 +17,7 @@
from __future__ import annotations
-from typing import Annotated
+from typing import TYPE_CHECKING, Annotated
from fastapi import Depends, status
@@ -41,8 +41,12 @@ from airflow.api_fastapi.core_api.datamodels.dag_stats import (
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import ReadableDagRunsFilterDep, requires_access_dag
from airflow.models.dagrun import DagRun
+from airflow.typing_compat import Unpack
from airflow.utils.state import DagRunState
+if TYPE_CHECKING:
+ from sqlalchemy import Result
+
dag_stats_router = AirflowRouter(tags=["DagStats"], prefix="/dagStats")
@@ -71,7 +75,8 @@ def get_dag_stats(
session=session,
return_total_entries=False,
)
- query_result = session.execute(dagruns_select)
+ # The below type annotation is acceptable on SQLA2.1, but not on 2.0
+    query_result: Result[Unpack[tuple[str, str, str, int]]] = session.execute(dagruns_select)  # type: ignore[type-arg]
result_dag_ids = []
dag_display_names: dict[str, str] = {}
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_tags.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_tags.py
index b02b9be31ec..86ba73e69da 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_tags.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_tags.py
@@ -17,6 +17,7 @@
from __future__ import annotations
+from collections.abc import Sequence
from typing import Annotated
from fastapi import Depends
@@ -67,5 +68,5 @@ def get_dag_tags(
limit=limit,
session=session,
)
- dag_tags = session.execute(dag_tags_select).scalars().all()
+ dag_tags: Sequence = session.execute(dag_tags_select).scalars().all()
return DAGTagCollectionResponse(tags=[x for x in dag_tags], total_entries=total_entries)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py
index 4ca9e342037..7bf64592aa6 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py
@@ -93,10 +93,11 @@ def get_xcom_entry(
# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
# retrieves the raw serialized value from the database.
- result = session.scalars(xcom_query).first()
+ raw_result: tuple[XComModel] | None = session.scalars(xcom_query).first()
- if result is None:
+ if raw_result is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"XCom entry with key: `{xcom_key}` not found")
+ result = raw_result[0] if isinstance(raw_result, tuple) else raw_result
item = copy.copy(result)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py
index 5a08ed1c3b0..110f34c780e 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py
@@ -35,6 +35,8 @@ from airflow.utils.state import State
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Iterator
+ from sqlalchemy import ScalarResult
+
@attrs.define
class DagRunWaiter:
@@ -57,10 +59,12 @@ class DagRunWaiter:
task_ids=self.result_task_ids,
dag_ids=self.dag_id,
)
-        xcom_results = self.session.scalars(xcom_query.order_by(XComModel.task_id, XComModel.map_index))
+ xcom_results: ScalarResult[tuple[XComModel]] = self.session.scalars(
+ xcom_query.order_by(XComModel.task_id, XComModel.map_index)
+ )
- def _group_xcoms(g: Iterator[XComModel]) -> Any:
- entries = list(g)
+ def _group_xcoms(g: Iterator[XComModel | tuple[XComModel]]) -> Any:
+ entries = [row[0] if isinstance(row, tuple) else row for row in g]
if len(entries) == 1 and entries[0].map_index < 0:  # Unpack non-mapped task xcom.
    return entries[0].value
return [entry.value for entry in entries]  # Task is mapped; return all xcoms in a list.
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
index 7763850b5ee..b3fd1cff7ee 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
@@ -190,7 +190,7 @@ def get_dagrun_state(
) -> DagRunStateResponse:
"""Get a Dag run State."""
try:
- state = session.scalars(
+ state: DagRunState = session.scalars(
select(DagRunModel.state).where(DagRunModel.dag_id == dag_id, DagRunModel.run_id == run_id)
).one()
except NoResultFound:
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
index 3408513a8c8..ec77b64dc44 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -113,6 +113,7 @@ def get_mapped_xcom_by_index(
else:
xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - offset)
+ result: tuple[XComModel] | None
if (result := session.scalars(xcom_query).first()) is None:
message = (
f"XCom with {key=} {offset=} not found for task {task_id!r} in DAG
run {run_id!r} of {dag_id!r}"
@@ -121,7 +122,7 @@ def get_mapped_xcom_by_index(
status_code=status.HTTP_404_NOT_FOUND,
detail={"reason": "not_found", "message": message},
)
- return XComSequenceIndexResponse(result.value)
+    return XComSequenceIndexResponse((result[0] if isinstance(result, tuple) else result).value)
class GetXComSliceFilterParams(BaseModel):
@@ -291,8 +292,8 @@ def get_xcom(
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
# (which automatically deserializes using the backend), we avoid potential
# performance hits from retrieving large data files into the API server.
- result = session.scalars(xcom_query).first()
- if result is None:
+ result: tuple[XComModel] | None
+ if (result := session.scalars(xcom_query).first()) is None:
if params.offset is None:
message = (
f"XCom with {key=} map_index={params.map_index} not found for "
@@ -308,7 +309,7 @@ def get_xcom(
detail={"reason": "not_found", "message": message},
)
- return XComResponse(key=key, value=result.value)
+    return XComResponse(key=key, value=(result[0] if isinstance(result, tuple) else result).value)
# TODO: once we have JWT tokens, then remove dag_id/run_id/task_id from the URL and just use the info in
diff --git a/airflow-core/src/airflow/dag_processing/collection.py b/airflow-core/src/airflow/dag_processing/collection.py
index f5ad10196b4..5abfbd8ae75 100644
--- a/airflow-core/src/airflow/dag_processing/collection.py
+++ b/airflow-core/src/airflow/dag_processing/collection.py
@@ -28,7 +28,7 @@ This should generally only be called by internal methods such as
from __future__ import annotations
import traceback
-from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, cast
+from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
import structlog
from sqlalchemy import delete, func, insert, select, tuple_, update
@@ -76,7 +76,7 @@ if TYPE_CHECKING:
from sqlalchemy.sql import Select
from airflow.models.dagwarning import DagWarning
- from airflow.typing_compat import Self
+ from airflow.typing_compat import Self, Unpack
AssetT = TypeVar("AssetT", SerializedAsset, SerializedAssetAlias)
@@ -512,15 +512,18 @@ class DagModelOperation(NamedTuple):
def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]:
"""Find existing DagModel objects from DAG objects."""
- stmt = (
- select(DagModel)
- .options(joinedload(DagModel.tags, innerjoin=False))
- .where(DagModel.dag_id.in_(self.dags))
- .options(joinedload(DagModel.schedule_asset_references))
- .options(joinedload(DagModel.schedule_asset_alias_references))
- .options(joinedload(DagModel.task_outlet_asset_references))
+ stmt: Select[Unpack[tuple[DagModel]]] = with_row_locks(
+ (
+ select(DagModel)
+ .options(joinedload(DagModel.tags, innerjoin=False))
+ .where(DagModel.dag_id.in_(self.dags))
+ .options(joinedload(DagModel.schedule_asset_references))
+ .options(joinedload(DagModel.schedule_asset_alias_references))
+ .options(joinedload(DagModel.task_outlet_asset_references))
+ ),
+ of=DagModel,
+ session=session,
)
-        stmt = cast("Select[tuple[DagModel]]", with_row_locks(stmt, of=DagModel, session=session))
return {dm.dag_id: dm for dm in session.scalars(stmt).unique()}
def add_dags(self, *, session: Session) -> dict[str, DagModel]:
@@ -711,7 +714,7 @@ def _find_all_asset_aliases(dags: Iterable[LazyDeserializedDAG]) -> Iterator[Ser
def _find_active_assets(name_uri_assets: Iterable[tuple[str, str]], session: Session) -> set[tuple[str, str]]:
return {
- tuple(row)
+ (str(row[0]), str(row[1]))
for row in session.execute(
select(AssetModel.name, AssetModel.uri).where(
tuple_(AssetModel.name, AssetModel.uri).in_(name_uri_assets),
@@ -906,7 +909,7 @@ class AssetModelOperation(NamedTuple):
if not references:
return
orm_refs = {
- tuple(row)
+ (str(row[0]), str(row[1]))
for row in session.execute(
select(model.dag_id, getattr(model, attr)).where(
model.dag_id.in_(dag_id for dag_id, _ in references)
diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py
index d77503a5885..81ae9afe577 100644
--- a/airflow-core/src/airflow/dag_processing/manager.py
+++ b/airflow-core/src/airflow/dag_processing/manager.py
@@ -75,7 +75,7 @@ from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks
if TYPE_CHECKING:
- from collections.abc import Callable, Iterable, Iterator
+ from collections.abc import Callable, Iterable, Iterator, Sequence
from socket import socket
from sqlalchemy.orm import Session
@@ -497,15 +497,17 @@ class DagFileProcessorManager(LoggingMixin):
callback_queue: list[CallbackRequest] = []
with prohibit_commit(session) as guard:
bundle_names = [bundle.name for bundle in self._dag_bundles]
- query: Select[tuple[DbCallbackRequest]] = select(DbCallbackRequest)
-            query = query.order_by(DbCallbackRequest.priority_weight.desc()).limit(
- self.max_callbacks_per_loop
- )
- query = cast(
- "Select[tuple[DbCallbackRequest]]",
-                with_row_locks(query, of=DbCallbackRequest, session=session, skip_locked=True),
+ query: Select[tuple[DbCallbackRequest]] = with_row_locks(
+ select(DbCallbackRequest)
+ .order_by(DbCallbackRequest.priority_weight.desc())
+ .limit(self.max_callbacks_per_loop),
+ of=DbCallbackRequest,
+ session=session,
+ skip_locked=True,
)
- callbacks = session.scalars(query)
+ callbacks: Sequence[DbCallbackRequest] = [
+                cb[0] if isinstance(cb, tuple) else cb for cb in session.scalars(query)
+ ]
for callback in callbacks:
req = callback.get_callback_request()
if req.bundle_name not in bundle_names:
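
(Illustrative aside, not part of the commit: the `cb[0] if isinstance(cb, tuple) else cb` pattern above is a small defensive unwrap that tolerates `scalars()` results arriving either as plain ORM objects or as 1-tuples. A minimal self-contained sketch, with a hypothetical `unwrap_scalars` helper name chosen for this example:)

    from __future__ import annotations

    from collections.abc import Iterable, Sequence
    from typing import TypeVar

    T = TypeVar("T")

    def unwrap_scalars(results: Iterable[T | tuple[T]]) -> Sequence[T]:
        """Return plain objects whether the iterable yields objects or 1-tuples."""
        return [r[0] if isinstance(r, tuple) else r for r in results]

    print(unwrap_scalars([("a",), "b"]))  # -> ['a', 'b']
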
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 0dc18ca696a..96b865f75bb 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -36,7 +36,6 @@ from typing import TYPE_CHECKING, Any
from sqlalchemy import (
and_,
delete,
- desc,
exists,
func,
inspect,
@@ -2578,7 +2577,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
Log.try_number == ti.try_number,
Log.event == "running",
)
- .order_by(desc(Log.dttm))
+ .order_by(Log.dttm.desc())
.limit(1)
)
@@ -2652,7 +2651,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
@provide_session
def _emit_running_dags_metric(self, session: Session = NEW_SESSION) -> None:
stmt = select(func.count()).select_from(DagRun).where(DagRun.state == DagRunState.RUNNING)
- running_dags = float(session.scalar(stmt))
+ running_dags = float(session.scalar(stmt) or 0)
Stats.gauge("scheduler.dagruns.running", running_dags)
@provide_session
diff --git a/airflow-core/src/airflow/models/backfill.py b/airflow-core/src/airflow/models/backfill.py
index 365e6c9b225..64828bdc10f 100644
--- a/airflow-core/src/airflow/models/backfill.py
+++ b/airflow-core/src/airflow/models/backfill.py
@@ -35,7 +35,6 @@ from sqlalchemy import (
Integer,
String,
UniqueConstraint,
- desc,
func,
select,
)
@@ -229,7 +228,7 @@ def _get_latest_dag_run_row_query(*, dag_id: str, info: DagRunInfo, session: Ses
DagRun.logical_date == info.logical_date,
DagRun.dag_id == dag_id,
)
- .order_by(nulls_first(desc(DagRun.start_date), session=session))
+ .order_by(nulls_first(DagRun.start_date.desc(), session=session))
.limit(1)
)
diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py
index 11985a42c5a..070304f30a7 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -18,6 +18,7 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
+from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, cast
@@ -185,7 +186,7 @@ class Deadline(Base):
dagruns_to_refresh = set()
for deadline, dagrun in deadline_dagrun_pairs:
- if dagrun.end_date <= deadline.deadline_time:
+            if dagrun.end_date is not None and dagrun.end_date <= deadline.deadline_time:
# If the DagRun finished before the Deadline:
session.delete(deadline)
Stats.incr(
@@ -403,7 +404,7 @@ class ReferenceModels:
query = query.limit(self.max_runs)
# Get all durations and calculate average
- durations = session.execute(query).scalars().all()
+ durations: Sequence = session.execute(query).scalars().all()
if len(durations) < cast("int", self.min_runs):
logger.info(
diff --git a/airflow-core/src/airflow/models/pool.py b/airflow-core/src/airflow/models/pool.py
index a00e8b90b58..dab8862a12c 100644
--- a/airflow-core/src/airflow/models/pool.py
+++ b/airflow-core/src/airflow/models/pool.py
@@ -191,7 +191,8 @@ class Pool(Base):
pools: dict[str, PoolStats] = {}
pool_includes_deferred: dict[str, bool] = {}
-        query: Select[Any] = select(Pool.pool, Pool.slots, Pool.include_deferred)
+        # The below type annotation is acceptable on SQLA2.1, but not on 2.0
+        query: Select[str, int, bool] = select(Pool.pool, Pool.slots, Pool.include_deferred)  # type: ignore[type-arg]
if lock_rows:
query = with_row_locks(query, session=session, nowait=True)
diff --git a/airflow-core/src/airflow/models/taskreschedule.py b/airflow-core/src/airflow/models/taskreschedule.py
index 005b6846845..88f87121fd7 100644
--- a/airflow-core/src/airflow/models/taskreschedule.py
+++ b/airflow-core/src/airflow/models/taskreschedule.py
@@ -28,8 +28,6 @@ from sqlalchemy import (
Index,
Integer,
String,
- asc,
- desc,
select,
)
from sqlalchemy.dialects import postgresql
@@ -94,4 +92,4 @@ class TaskReschedule(Base):
:param descending: If True then records are returned in descending order
:meta private:
"""
-        return select(cls).where(cls.ti_id == ti.id).order_by(desc(cls.id) if descending else asc(cls.id))
+        return select(cls).where(cls.ti_id == ti.id).order_by(cls.id.desc() if descending else cls.id.asc())
diff --git a/airflow-core/src/airflow/serialization/definitions/deadline.py b/airflow-core/src/airflow/serialization/definitions/deadline.py
index 2fefbcb86af..78adc6b9a76 100644
--- a/airflow-core/src/airflow/serialization/definitions/deadline.py
+++ b/airflow-core/src/airflow/serialization/definitions/deadline.py
@@ -32,6 +32,8 @@ from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import get_dialect_name
if TYPE_CHECKING:
+ from collections.abc import Sequence
+
from sqlalchemy import ColumnElement
from sqlalchemy.orm import Session
@@ -210,7 +212,7 @@ class SerializedReferenceModels:
.limit(self.max_runs)
)
- durations = list(session.execute(query).scalars())
+ durations: Sequence = session.execute(query).scalars().all()
min_runs = self.min_runs or 0
if len(durations) < min_runs:
diff --git a/airflow-core/src/airflow/serialization/definitions/node.py b/airflow-core/src/airflow/serialization/definitions/node.py
index 06c61a54de5..2cbdc9db771 100644
--- a/airflow-core/src/airflow/serialization/definitions/node.py
+++ b/airflow-core/src/airflow/serialization/definitions/node.py
@@ -32,11 +32,16 @@ if TYPE_CHECKING:
__all__ = ["DAGNode"]
-class DAGNode(GenericDAGNode["SerializedDAG", "Operator", "SerializedTaskGroup"], metaclass=abc.ABCMeta):
+class DAGNode(GenericDAGNode["SerializedDAG", "Operator", "SerializedTaskGroup"], metaclass=abc.ABCMeta):  # type: ignore[type-var]
"""
Base class for a node in the graph of a workflow.
A node may be an operator or task group, either mapped or unmapped.
+
+    Note: type: ignore is used because SerializedBaseOperator and SerializedTaskGroup
+    don't have explicit type annotations for all attributes required by TaskProtocol
+    and TaskGroupProtocol (they inherit them from GenericDAGNode). This is acceptable
+ because they are implemented correctly at runtime.
"""
@property
diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 893807fa599..5d2b6955d75 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -40,6 +40,7 @@ if TYPE_CHECKING:
from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
+ from airflow.typing_compat import Unpack
class _UpstreamTIStates(NamedTuple):
@@ -371,7 +372,8 @@ class TriggerRuleDep(BaseTIDep):
upstream = len(upstream_tasks)
upstream_setup = sum(1 for x in upstream_tasks.values() if x.is_setup)
else:
-            task_id_counts: Sequence[Row[tuple[str, int]]] = session.execute(
+            # The below type annotation is acceptable on SQLA2.1, but not on 2.0
+            task_id_counts: Sequence[Row[Unpack[tuple[str, int]]]] = session.execute(  # type: ignore[type-arg]
select(TaskInstance.task_id, func.count(TaskInstance.task_id))
.where(TaskInstance.dag_id == ti.dag_id, TaskInstance.run_id == ti.run_id)
.where(or_(*_iter_upstream_conditions(relevant_tasks=upstream_tasks)))
diff --git a/airflow-core/src/airflow/typing_compat.py b/airflow-core/src/airflow/typing_compat.py
index 8a00ac06bd7..e1efb87067b 100644
--- a/airflow-core/src/airflow/typing_compat.py
+++ b/airflow-core/src/airflow/typing_compat.py
@@ -19,13 +19,7 @@
from __future__ import annotations
-__all__ = [
- "Literal",
- "ParamSpec",
- "Self",
- "TypeAlias",
- "TypeGuard",
-]
+__all__ = ["Literal", "ParamSpec", "Self", "TypeAlias", "TypeGuard", "Unpack"]
import sys
@@ -33,6 +27,6 @@ import sys
from typing import Literal, ParamSpec, TypeAlias, TypeGuard
if sys.version_info >= (3, 11):
- from typing import Self
+ from typing import Self, Unpack
else:
- from typing_extensions import Self
+ from typing_extensions import Self, Unpack
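
(Illustrative aside, not part of the commit: the version-gated import that typing_compat now applies to Unpack can be tried on its own. This sketch uses a made-up function name `tail`; it only assumes typing_extensions is available on Python < 3.11.)

    import sys

    if sys.version_info >= (3, 11):
        from typing import Unpack
    else:
        from typing_extensions import Unpack

    def tail(*row: Unpack[tuple[str, int]]) -> int:
        """Accept exactly a (str, int) pair of positional arguments."""
        return row[1]

    print(tail("dag_id", 42))  # -> 42
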
diff --git a/airflow-core/src/airflow/utils/sqlalchemy.py b/airflow-core/src/airflow/utils/sqlalchemy.py
index 8d9e826bef7..266be08c3bb 100644
--- a/airflow-core/src/airflow/utils/sqlalchemy.py
+++ b/airflow-core/src/airflow/utils/sqlalchemy.py
@@ -22,7 +22,7 @@ import copy
import datetime
import logging
from collections.abc import Generator
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst
from sqlalchemy.dialects import mysql
@@ -319,14 +319,14 @@ USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_lockin
def with_row_locks(
- query: Select[Any],
+ query: Select,
session: Session,
*,
nowait: bool = False,
skip_locked: bool = False,
key_share: bool = True,
**kwargs,
-) -> Select[Any]:
+) -> Select:
"""
Apply with_for_update to the SQLAlchemy query if row level locking is in use.
diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py
index 2a45735847e..9683a500d95 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -118,6 +118,7 @@ from unit.plugins.priority_weight_strategy import (
)
if TYPE_CHECKING:
+ from sqlalchemy.engine import ScalarResult
from sqlalchemy.orm import Session
pytestmark = pytest.mark.db_test
@@ -2361,7 +2362,7 @@ class TestDagModel:
)
SerializedDAG.bulk_write_to_db("testing", None, [dag], session=session)
- expression = session.scalars(
+ expression: ScalarResult = session.scalars(
select(DagModel.asset_expression).where(DagModel.dag_id == dag.dag_id)
).one()
assert expression == {
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 08dffc78dc5..1b6b3a01d1b 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -1733,7 +1733,9 @@ class TestTaskInstance:
for ti in dr.get_task_instances(session=session):
run_task_instance(ti, dag_maker.dag.get_task(ti.task_id), session=session)
-        events = dict((tuple(row)) for row in session.execute(select(AssetEvent.source_task_id, AssetEvent)))
+ events: dict[str, AssetEvent] = dict(
+            (str(row[0]), row[1]) for row in session.execute(select(AssetEvent.source_task_id, AssetEvent))
+ )
assert set(events) == {"write1", "write2"}
assert events["write1"].source_dag_id == dr.dag_id
diff --git a/providers/google/src/airflow/providers/google/cloud/_internal_client/secret_manager_client.py b/providers/google/src/airflow/providers/google/cloud/_internal_client/secret_manager_client.py
index a78d0c7bdb2..0ddf01d99fb 100644
--- a/providers/google/src/airflow/providers/google/cloud/_internal_client/secret_manager_client.py
+++ b/providers/google/src/airflow/providers/google/cloud/_internal_client/secret_manager_client.py
@@ -27,7 +27,7 @@ from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.utils.log.logging_mixin import LoggingMixin
if TYPE_CHECKING:
- import google
+ from google.auth.credentials import Credentials
SECRET_ID_PATTERN = r"^[a-zA-Z0-9-_]*$"
@@ -45,7 +45,7 @@ class _SecretManagerClient(LoggingMixin):
def __init__(
self,
- credentials: google.auth.credentials.Credentials,
+ credentials: Credentials,
) -> None:
super().__init__()
self.credentials = credentials
diff --git a/providers/google/src/airflow/providers/google/common/hooks/base_google.py b/providers/google/src/airflow/providers/google/common/hooks/base_google.py
index b55dce384eb..0a7a0805a4f 100644
--- a/providers/google/src/airflow/providers/google/common/hooks/base_google.py
+++ b/providers/google/src/airflow/providers/google/common/hooks/base_google.py
@@ -718,6 +718,7 @@ class _CredentialsToken(Token):
super().__init__(session=cast("Session", session), scopes=_scopes)
self.credentials = credentials
self.project = project
+ self.acquiring: asyncio.Task[None] | None = None
@classmethod
async def from_hook(
diff --git a/shared/dagnode/src/airflow_shared/dagnode/node.py b/shared/dagnode/src/airflow_shared/dagnode/node.py
index 0fed0c97f75..7d52ff1ea1f 100644
--- a/shared/dagnode/src/airflow_shared/dagnode/node.py
+++ b/shared/dagnode/src/airflow_shared/dagnode/node.py
@@ -17,18 +17,54 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
import structlog
if TYPE_CHECKING:
+ import sys
from collections.abc import Collection, Iterable
+ # Replicate `airflow.typing_compat.Self` to avoid illegal imports
+ if sys.version_info >= (3, 11):
+ from typing import Self
+ else:
+ from typing_extensions import Self
+
from ..logging.types import Logger
-Dag = TypeVar("Dag")
-Task = TypeVar("Task")
-TaskGroup = TypeVar("TaskGroup")
+
+class DagProtocol(Protocol):
+ """Protocol defining the minimum interface required for Dag generic
type."""
+
+ dag_id: str
+ task_dict: dict[str, Any]
+
+ def get_task(self, tid: str) -> Any:
+ """Retrieve a task by its task ID."""
+ ...
+
+
+class TaskProtocol(Protocol):
+ """Protocol defining the minimum interface required for Task generic
type."""
+
+ task_id: str
+ is_setup: bool
+ is_teardown: bool
+ downstream_list: Iterable[Self]
+ downstream_task_ids: set[str]
+
+
+class TaskGroupProtocol(Protocol):
+ """Protocol defining the minimum interface required for TaskGroup generic
type."""
+
+ node_id: str
+ prefix_group_id: bool
+
+
+Dag = TypeVar("Dag", bound=DagProtocol)
+Task = TypeVar("Task", bound=TaskProtocol)
+TaskGroup = TypeVar("TaskGroup", bound=TaskGroupProtocol)
class GenericDAGNode(Generic[Dag, Task, TaskGroup]):
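
(Illustrative aside, not part of the commit: bounding a TypeVar with a Protocol, as the hunk above does for Dag, Task, and TaskGroup, is what lets a generic class use the bound's attributes on its type parameter. A minimal self-contained sketch; `HasNodeId`, `Registry`, and `Task` here are invented names for the example only.)

    from __future__ import annotations

    from typing import Generic, Protocol, TypeVar

    class HasNodeId(Protocol):
        node_id: str

    N = TypeVar("N", bound=HasNodeId)

    class Registry(Generic[N]):
        """Toy generic container; the Protocol bound guarantees .node_id exists on N."""

        def __init__(self) -> None:
            self._nodes: dict[str, N] = {}

        def add(self, node: N) -> None:
            self._nodes[node.node_id] = node

    class Task:
        def __init__(self, node_id: str) -> None:
            self.node_id = node_id

    registry: Registry[Task] = Registry()
    registry.add(Task("extract"))
    print(sorted(registry._nodes))  # -> ['extract']
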
diff --git a/shared/observability/src/airflow_shared/observability/metrics/datadog_logger.py b/shared/observability/src/airflow_shared/observability/metrics/datadog_logger.py
index 595e6c8a33f..09ac6b15a71 100644
--- a/shared/observability/src/airflow_shared/observability/metrics/datadog_logger.py
+++ b/shared/observability/src/airflow_shared/observability/metrics/datadog_logger.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import datetime
import logging
from collections.abc import Callable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
from .protocols import Timer
from .validators import (
@@ -176,7 +176,7 @@ def get_dogstatsd_logger(
"""Get DataDog StatsD logger."""
from datadog import DogStatsd
- dogstatsd_kwargs: dict[str, str | int | list[str]] = {
+ dogstatsd_kwargs: dict[str, Any] = {
"constant_tags": cls.get_constant_tags(tags_in_string=tags_in_string),
}
if host is not None:
diff --git a/shared/observability/src/airflow_shared/observability/metrics/stats.py b/shared/observability/src/airflow_shared/observability/metrics/stats.py
index e2c3e63077d..e477a751ec0 100644
--- a/shared/observability/src/airflow_shared/observability/metrics/stats.py
+++ b/shared/observability/src/airflow_shared/observability/metrics/stats.py
@@ -56,6 +56,7 @@ class _Stats(type):
def initialize(cls, *, is_statsd_datadog_enabled: bool, is_statsd_on: bool, is_otel_on: bool) -> None:
type.__setattr__(cls, "factory", None)
type.__setattr__(cls, "instance", None)
+ factory: Callable
if is_statsd_datadog_enabled:
from airflow.observability.metrics import datadog_logger
diff --git a/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py b/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py
index 9ea497e5a10..8fcc5c9c808 100644
--- a/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py
+++ b/shared/plugins_manager/src/airflow_shared/plugins_manager/plugins_manager.py
@@ -19,6 +19,9 @@
from __future__ import annotations
+import importlib
+import importlib.machinery
+import importlib.util
import inspect
import logging
import os
@@ -208,8 +211,6 @@ def _load_plugins_from_plugin_directory(
ignore_file_syntax: str = "glob",
) -> tuple[list[AirflowPlugin], dict[str, str]]:
"""Load and register Airflow Plugins from plugins directory."""
- import importlib
-
from ..module_loading import find_path_from_directory
if not plugins_folder:
@@ -219,6 +220,8 @@ def _load_plugins_from_plugin_directory(
plugin_search_locations: list[tuple[str, Generator[str, None, None]]] = [("", files)]
if load_examples:
+ if not example_plugins_module:
+            raise ValueError("example_plugins_module is required when load_examples is True")
log.debug("Note: Loading plugins from examples as well: %s",
plugins_folder)
example_plugins = importlib.import_module(example_plugins_module)
example_plugins_folder = next(iter(example_plugins.__path__))
diff --git a/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py b/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py
index 4fc882d1b5d..dcab0fe3034 100644
--- a/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py
+++ b/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py
@@ -27,7 +27,7 @@ from dataclasses import dataclass
from functools import wraps
from importlib.resources import files as resource_files
from time import perf_counter
-from typing import Any, NamedTuple, ParamSpec
+from typing import Any, NamedTuple, ParamSpec, Protocol, cast
import structlog
from packaging.utils import canonicalize_name
@@ -43,6 +43,12 @@ PS = ParamSpec("PS")
KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS = [("apache-airflow-providers-google", "No module named 'paramiko'")]
+class ProvidersManagerProtocol(Protocol):
+ """Protocol for ProvidersManager for type checking purposes."""
+
+ _initialized_cache: dict[str, bool]
+
+
@dataclass
class ProviderInfo:
"""
@@ -271,7 +277,7 @@ def provider_info_cache(cache_name: str) -> Callable[[Callable[PS, None]], Calla
    def provider_info_cache_decorator(func: Callable[PS, None]) -> Callable[PS, None]:
@wraps(func)
def wrapped_function(*args: PS.args, **kwargs: PS.kwargs) -> None:
- instance = args[0]
+ instance = cast("ProvidersManagerProtocol", args[0])
if cache_name in instance._initialized_cache:
return
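
(Illustrative aside, not part of the commit: casting an opaque argument to a Protocol, as the hunk above does with ProvidersManagerProtocol, is a no-op at runtime but gives the type checker access to the declared attribute. A minimal sketch; `HasInitializedCache`, `Manager`, and `mark_initialized` are names invented for this example.)

    from __future__ import annotations

    from typing import Protocol, cast

    class HasInitializedCache(Protocol):
        _initialized_cache: dict[str, bool]

    class Manager:
        def __init__(self) -> None:
            self._initialized_cache: dict[str, bool] = {}

    def mark_initialized(obj: object, name: str) -> None:
        # cast() does nothing at runtime; it only tells the checker which attributes exist.
        inst = cast("HasInitializedCache", obj)
        inst._initialized_cache[name] = True

    manager = Manager()
    mark_initialized(manager, "hooks")
    print(manager._initialized_cache)  # -> {'hooks': True}
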
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/node.py b/task-sdk/src/airflow/sdk/definitions/_internal/node.py
index 4ee812b103d..803ca837825 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/node.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/node.py
@@ -64,7 +64,7 @@ def validate_group_key(k: str, max_length: int = 200):
)
-class DAGNode(GenericDAGNode["DAG", "Operator", "TaskGroup"], DependencyMixin, metaclass=ABCMeta):
+class DAGNode(GenericDAGNode["DAG", "Operator", "TaskGroup"], DependencyMixin, metaclass=ABCMeta):  # type: ignore[type-var]
"""
A base class for a node in the graph of a workflow.