This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6775bf7bae1 Make `ExternalTaskSensor` work with Task SDK (#48651)
6775bf7bae1 is described below
commit 6775bf7bae13f4291e18d4118179c14e4444de0d
Author: Kaxil Naik <[email protected]>
AuthorDate: Thu Apr 3 13:57:49 2025 +0530
Make `ExternalTaskSensor` work with Task SDK (#48651)
closes https://github.com/apache/airflow/issues/47447
closes https://github.com/apache/airflow/issues/47948
---
.../api_fastapi/execution_api/routes/dag_runs.py | 30 +-
.../execution_api/routes/task_instances.py | 108 ++++-
airflow-core/tests/conftest.py | 11 -
.../execution_api/versions/head/test_dag_runs.py | 98 ++++-
.../versions/head/test_task_instances.py | 165 +++++++
devel-common/src/tests_common/pytest_plugin.py | 11 +
.../providers/standard/sensors/external_task.py | 128 ++++--
.../providers/standard/utils/sensor_helper.py | 9 +-
.../standard}/sensors/test_external_task_sensor.py | 490 ++++++++++++++++-----
task-sdk/src/airflow/sdk/api/client.py | 48 ++
task-sdk/src/airflow/sdk/execution_time/comms.py | 42 +-
.../src/airflow/sdk/execution_time/supervisor.py | 20 +
.../src/airflow/sdk/execution_time/task_runner.py | 60 +++
task-sdk/src/airflow/sdk/types.py | 18 +
task-sdk/tests/task_sdk/api/test_client.py | 94 ++++
.../task_sdk/execution_time/test_supervisor.py | 34 ++
.../task_sdk/execution_time/test_task_runner.py | 52 +++
17 files changed, 1261 insertions(+), 157 deletions(-)
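In short: on Airflow 3 the sensor no longer opens a database session inside poke(); it asks the runtime TaskInstance for counts, which the Task SDK supervisor resolves against the new execution API endpoints added below. A simplified sketch of the new dispatch (mirroring the provider diff further down; error handling omitted):

    def poke(self, context: Context) -> bool:
        dttm_filter = self._get_dttm_filter(context)
        if AIRFLOW_V_3_0_PLUS:
            # Airflow 3: counts flow over supervisor comms to the execution API.
            return self._poke_af3(context, dttm_filter)
        # Airflow 2: direct-DB path, still behind @provide_session.
        return self._poke_af2(dttm_filter)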
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
index 3a680c1ef8c..a4398601138 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
@@ -18,13 +18,15 @@
from __future__ import annotations
import logging
+from typing import Annotated
-from fastapi import HTTPException, status
-from sqlalchemy import select
+from fastapi import HTTPException, Query, status
+from sqlalchemy import func, select
from airflow.api.common.trigger_dag import trigger_dag
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
+from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload
from airflow.exceptions import DagRunAlreadyExists
from airflow.models.dag import DagModel
@@ -150,3 +152,27 @@ def get_dagrun_state(
)
return DagRunStateResponse(state=dag_run.state)
+
+
[email protected]("/count", status_code=status.HTTP_200_OK)
+def get_dr_count(
+ dag_id: str,
+ session: SessionDep,
+ logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None,
+ run_ids: Annotated[list[str] | None, Query()] = None,
+ states: Annotated[list[str] | None, Query()] = None,
+) -> int:
+ """Get the count of DAG runs matching the given criteria."""
+ query = select(func.count()).select_from(DagRun).where(DagRun.dag_id == dag_id)
+
+ if logical_dates:
+ query = query.where(DagRun.logical_date.in_(logical_dates))
+
+ if run_ids:
+ query = query.where(DagRun.run_id.in_(run_ids))
+
+ if states:
+ query = query.where(DagRun.state.in_(states))
+
+ count = session.scalar(query)
+ return count or 0
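For reference, the route returns a bare integer and is exercised by the tests added below roughly like this (a sketch; `client` is the execution-API test client and the dag_id is illustrative):

    response = client.get(
        "/execution/dag-runs/count",
        params={"dag_id": "my_dag", "states": ["success", "failed"]},
    )
    assert response.status_code == 200
    assert isinstance(response.json(), int)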
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index ff0ca516314..6dd0732c077 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -23,13 +23,14 @@ from typing import Annotated
from uuid import UUID
from cadwyn import VersionedAPIRouter
-from fastapi import Body, Depends, HTTPException, status
+from fastapi import Body, Depends, HTTPException, Query, status
from pydantic import JsonValue
-from sqlalchemy import func, tuple_, update
+from sqlalchemy import func, or_, tuple_, update
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.sql import select
from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.common.types import UtcDateTime
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
PrevSuccessfulDagRunResponse,
TIDeferredStatePayload,
@@ -45,6 +46,7 @@ from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
TITerminalStatePayload,
)
from airflow.api_fastapi.execution_api.deps import JWTBearer
+from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.taskreschedule import TaskReschedule
@@ -53,7 +55,9 @@ from airflow.models.xcom import XComModel
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState
-router = VersionedAPIRouter(
+router = VersionedAPIRouter()
+
+ti_id_router = VersionedAPIRouter(
dependencies=[
# This checks that the UUID in the url matches the one in the token for us.
Depends(JWTBearer(path_param_name="task_instance_id")),
@@ -64,7 +68,7 @@ router = VersionedAPIRouter(
log = logging.getLogger(__name__)
[email protected](
+@ti_id_router.patch(
"/{task_instance_id}/run",
status_code=status.HTTP_200_OK,
responses={
@@ -243,7 +247,7 @@ def ti_run(
)
[email protected](
+@ti_id_router.patch(
"/{task_instance_id}/state",
status_code=status.HTTP_204_NO_CONTENT,
responses={
@@ -404,7 +408,7 @@ def ti_update_state(
)
[email protected](
+@ti_id_router.patch(
"/{task_instance_id}/skip-downstream",
status_code=status.HTTP_204_NO_CONTENT,
responses={
@@ -436,7 +440,7 @@ def ti_skip_downstream(
log.info("TI %s updated the state of %s task(s) to skipped", ti_id_str, result.rowcount)
[email protected](
+@ti_id_router.put(
"/{task_instance_id}/heartbeat",
status_code=status.HTTP_204_NO_CONTENT,
responses={
@@ -498,7 +502,7 @@ def ti_heartbeat(
log.debug("Task with %s state heartbeated", previous_state)
[email protected](
+@ti_id_router.put(
"/{task_instance_id}/rtif",
status_code=status.HTTP_201_CREATED,
# TODO: Add description to the operation
@@ -528,7 +532,7 @@ def ti_put_rtif(
return {"message": "Rendered task instance fields successfully set"}
[email protected](
+@ti_id_router.get(
"/{task_instance_id}/previous-successful-dagrun",
status_code=status.HTTP_200_OK,
responses={
@@ -564,8 +568,86 @@ def get_previous_successful_dagrun(
return PrevSuccessfulDagRunResponse.model_validate(dag_run)
[email protected]_exists_in_older_versions
[email protected](
[email protected]("/count", status_code=status.HTTP_200_OK)
+def get_count(
+ dag_id: str,
+ session: SessionDep,
+ task_ids: Annotated[list[str] | None, Query()] = None,
+ task_group_id: Annotated[str | None, Query()] = None,
+ logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None,
+ run_ids: Annotated[list[str] | None, Query()] = None,
+ states: Annotated[list[str] | None, Query()] = None,
+) -> int:
+ """Get the count of task instances matching the given criteria."""
+ query = select(func.count()).select_from(TI).where(TI.dag_id == dag_id)
+
+ if task_ids:
+ query = query.where(TI.task_id.in_(task_ids))
+
+ if logical_dates:
+ query = query.where(TI.logical_date.in_(logical_dates))
+
+ if run_ids:
+ query = query.where(TI.run_id.in_(run_ids))
+
+ if task_group_id:
+ # Get all tasks in the task group
+ dag = DagBag(read_dags_from_db=True).get_dag(dag_id, session)
+ if not dag:
+ raise HTTPException(
+ status.HTTP_404_NOT_FOUND,
+ detail={
+ "reason": "not_found",
+ "message": f"DAG {dag_id} not found",
+ },
+ )
+
+ task_group = dag.task_group_dict.get(task_group_id)
+ if not task_group:
+ raise HTTPException(
+ status.HTTP_404_NOT_FOUND,
+ detail={
+ "reason": "not_found",
+ "message": f"Task group {task_group_id} not found in DAG {dag_id}",
+ },
+ )
+
+ # First get all task instances to get the task_id, map_index pairs
+ group_tasks = session.scalars(
+ select(TI).where(
+ TI.dag_id == dag_id,
+ TI.task_id.in_(task.task_id for task in task_group.iter_tasks()),
+ *([TI.logical_date.in_(logical_dates)] if logical_dates else []),
+ *([TI.run_id.in_(run_ids)] if run_ids else []),
+ )
+ ).all()
+
+ # Get unique (task_id, map_index) pairs
+ task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks]
+ if not task_map_pairs:
+ # If no task group tasks found, default to checking the task group ID itself
+ # This matches the behavior in _get_external_task_group_task_ids
+ task_map_pairs = [(task_group_id, -1)]
+
+ # Update query to use task_id, map_index pairs
+ query = query.where(tuple_(TI.task_id, TI.map_index).in_(task_map_pairs))
+
+ if states:
+ if "null" in states:
+ not_none_states = [s for s in states if s != "null"]
+ if not_none_states:
+ query = query.where(or_(TI.state.is_(None), TI.state.in_(not_none_states)))
+ else:
+ query = query.where(TI.state.is_(None))
+ else:
+ query = query.where(TI.state.in_(states))
+
+ count = session.scalar(query)
+ return count or 0
+
+
+@ti_id_router.only_exists_in_older_versions
+@ti_id_router.post(
"/{task_instance_id}/runtime-checks",
status_code=status.HTTP_204_NO_CONTENT,
# TODO: Add description to the operation
@@ -602,3 +684,7 @@ def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
# max_tries is initialised with the retries defined at task level, we do not need to explicitly ask for
# retries from the task SDK now, we can handle using max_tries
return max_tries != 0 and try_number <= max_tries
+
+
+# This line should be at the end of the file to ensure all routes are registered
+router.include_router(ti_id_router)
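The task-instance variant takes the same filters plus task_ids/task_group_id, and treats the literal string "null" as "state IS NULL" (the or_ branch above). A sketch of a call, with illustrative IDs:

    # Count group1's task instances that are queued or still have no state.
    response = client.get(
        "/execution/task-instances/count",
        params={"dag_id": "my_dag", "task_group_id": "group1", "states": ["queued", "null"]},
    )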
diff --git a/airflow-core/tests/conftest.py b/airflow-core/tests/conftest.py
index c5affb469f4..c605cc7648e 100644
--- a/airflow-core/tests/conftest.py
+++ b/airflow-core/tests/conftest.py
@@ -78,17 +78,6 @@ def clear_all_logger_handlers():
remove_all_non_pytest_log_handlers()
[email protected]
-def testing_dag_bundle():
- from airflow.models.dagbundle import DagBundleModel
- from airflow.utils.session import create_session
-
- with create_session() as session:
- if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0:
- testing = DagBundleModel(name="testing")
- session.add(testing)
-
-
@contextmanager
def _config_bundles(bundles: dict[str, Path | str]):
from tests_common.test_utils.config import conf_vars
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py
index c5624f188b5..f9f8d489d3d 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py
@@ -23,7 +23,7 @@ from airflow.models import DagModel
from airflow.models.dagrun import DagRun
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.utils import timezone
-from airflow.utils.state import DagRunState
+from airflow.utils.state import DagRunState, State
from tests_common.test_utils.db import clear_db_runs
@@ -218,3 +218,99 @@ class TestDagRunState:
response = client.post(f"/execution/dag-runs/{dag_id}/{run_id}/clear")
assert response.status_code == 404
+
+
+class TestGetDagRunCount:
+ def setup_method(self):
+ clear_db_runs()
+
+ def teardown_method(self):
+ clear_db_runs()
+
+ def test_get_count_basic(self, client, session, dag_maker):
+ with dag_maker("test_dag"):
+ pass
+ dag_maker.create_dagrun()
+ session.commit()
+
+ response = client.get("/execution/dag-runs/count", params={"dag_id": "test_dag"})
+ assert response.status_code == 200
+ assert response.json() == 1
+
+ def test_get_count_with_states(self, client, session, dag_maker):
+ """Test counting DAG runs in specific states."""
+ with dag_maker("test_get_count_with_states"):
+ pass
+
+ # Create DAG runs with different states
+ dag_maker.create_dagrun(
+ state=State.SUCCESS, logical_date=timezone.datetime(2025, 1, 1), run_id="test_run_id1"
+ )
+ dag_maker.create_dagrun(
+ state=State.FAILED, logical_date=timezone.datetime(2025, 1, 2), run_id="test_run_id2"
+ )
+ dag_maker.create_dagrun(
+ state=State.RUNNING, logical_date=timezone.datetime(2025, 1, 3), run_id="test_run_id3"
+ )
+ session.commit()
+
+ response = client.get(
+ "/execution/dag-runs/count",
+ params={"dag_id": "test_get_count_with_states", "states": [State.SUCCESS, State.FAILED]},
+ )
+ assert response.status_code == 200
+ assert response.json() == 2
+
+ def test_get_count_with_logical_dates(self, client, session, dag_maker):
+ with dag_maker("test_get_count_with_logical_dates"):
+ pass
+
+ date1 = timezone.datetime(2025, 1, 1)
+ date2 = timezone.datetime(2025, 1, 2)
+
+ dag_maker.create_dagrun(run_id="test_run_id1", logical_date=date1)
+ dag_maker.create_dagrun(run_id="test_run_id2", logical_date=date2)
+ session.commit()
+
+ response = client.get(
+ "/execution/dag-runs/count",
+ params={
+ "dag_id": "test_get_count_with_logical_dates",
+ "logical_dates": [date1.isoformat(), date2.isoformat()],
+ },
+ )
+ assert response.status_code == 200
+ assert response.json() == 2
+
+ def test_get_count_with_run_ids(self, client, session, dag_maker):
+ with dag_maker("test_get_count_with_run_ids"):
+ pass
+
+ dag_maker.create_dagrun(run_id="run1", logical_date=timezone.datetime(2025, 1, 1))
+ dag_maker.create_dagrun(run_id="run2", logical_date=timezone.datetime(2025, 1, 2))
+ session.commit()
+
+ response = client.get(
+ "/execution/dag-runs/count",
+ params={"dag_id": "test_get_count_with_run_ids", "run_ids": ["run1", "run2"]},
+ )
+ assert response.status_code == 200
+ assert response.json() == 2
+
+ def test_get_count_with_mixed_states(self, client, session, dag_maker):
+ with dag_maker("test_get_count_with_mixed"):
+ pass
+ dag_maker.create_dagrun(
+ state=State.SUCCESS, run_id="runid1", logical_date=timezone.datetime(2025, 1, 1)
+ )
+ dag_maker.create_dagrun(
+ state=State.QUEUED, run_id="runid2", logical_date=timezone.datetime(2025, 1, 2)
+ )
+ session.commit()
+
+ response = client.get(
+ "/execution/dag-runs/count",
+ params={"dag_id": "test_get_count_with_mixed", "states": [State.SUCCESS, State.QUEUED]},
+ )
+ assert response.status_code == 200
+ assert response.json() == 2
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 81c99b271be..147209e967a 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -33,6 +33,7 @@ from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, Asset
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.providers.standard.operators.empty import EmptyOperator
+from airflow.sdk import TaskGroup
from airflow.utils import timezone
from airflow.utils.state import State, TaskInstanceState, TerminalTIState
@@ -1223,3 +1224,167 @@ class TestGetRescheduleStartDate:
response = client.get(f"/execution/task-reschedules/{ti.id}/start_date?try_number=2")
assert response.status_code == 200
assert response.json() == "2024-01-02T00:00:00Z"
+
+
+class TestGetCount:
+ def setup_method(self):
+ clear_db_runs()
+
+ def teardown_method(self):
+ clear_db_runs()
+
+ def test_get_count_basic(self, client, session, create_task_instance):
+ create_task_instance(task_id="test_task", state=State.SUCCESS)
+ session.commit()
+
+ response = client.get("/execution/task-instances/count", params={"dag_id": "dag"})
+ assert response.status_code == 200
+ assert response.json() == 1
+
+ def test_get_count_with_task_ids(self, client, session, create_task_instance):
+ for i in range(3):
+ create_task_instance(
+ task_id=f"task{i}",
+ state=State.SUCCESS,
+ dag_id="test_get_count_with_task_ids",
+ run_id=f"test_run_id{i}",
+ )
+ session.commit()
+
+ response = client.get(
+ "/execution/task-instances/count",
+ params={"dag_id": "test_get_count_with_task_ids", "task_ids": ["task1", "task2"]},
+ )
+ assert response.status_code == 200
+ assert response.json() == 2
+
+ def test_get_count_with_states(self, client, session, dag_maker):
+ """Test counting tasks in specific states."""
+ with dag_maker("test_get_count_with_states"):
+ EmptyOperator(task_id="task1")
+ EmptyOperator(task_id="task2")
+ EmptyOperator(task_id="task3")
+
+ dr = dag_maker.create_dagrun()
+
+ tis = dr.get_task_instances()
+
+ # Set different states for the task instances
+ for ti, state in zip(tis, [State.SUCCESS, State.FAILED, State.SKIPPED]):
+ ti.state = state
+ session.merge(ti)
+ session.commit()
+
+ response = client.get(
+ "/execution/task-instances/count",
+ params={"dag_id": "test_get_count_with_states", "states": [State.SUCCESS, State.FAILED]},
+ )
+ assert response.status_code == 200
+ assert response.json() == 2
+
+ def test_get_count_with_logical_dates(self, client, session, dag_maker):
+ with dag_maker("test_get_count_with_logical_dates"):
+ EmptyOperator(task_id="task1")
+
+ date1 = timezone.datetime(2025, 1, 1)
+ date2 = timezone.datetime(2025, 1, 2)
+
+ dag_maker.create_dagrun(run_id="test_run_id1", logical_date=date1)
+ dag_maker.create_dagrun(run_id="test_run_id2", logical_date=date2)
+
+ session.commit()
+
+ response = client.get(
+ "/execution/task-instances/count",
+ params={
+ "dag_id": "test_get_count_with_logical_dates",
+ "logical_dates": [date1.isoformat(), date2.isoformat()],
+ },
+ )
+ assert response.status_code == 200
+ assert response.json() == 2
+
+ def test_get_count_with_run_ids(self, client, session, dag_maker):
+ with dag_maker("test_get_count_with_run_ids"):
+ EmptyOperator(task_id="task1")
+
+ dag_maker.create_dagrun(run_id="run1", logical_date=timezone.datetime(2025, 1, 1))
+ dag_maker.create_dagrun(run_id="run2", logical_date=timezone.datetime(2025, 1, 2))
+
+ session.commit()
+
+ response = client.get(
+ "/execution/task-instances/count",
+ params={"dag_id": "test_get_count_with_run_ids", "run_ids": ["run1", "run2"]},
+ )
+ assert response.status_code == 200
+ assert response.json() == 2
+
+ def test_get_count_with_task_group(self, client, session, dag_maker):
+ with dag_maker(dag_id="test_dag", serialized=True):
+ with TaskGroup("group1"):
+ EmptyOperator(task_id="task1")
+ EmptyOperator(task_id="task2")
+
+ with TaskGroup("group2"):
+ EmptyOperator(task_id="task3")
+
+ dag_maker.create_dagrun(session=session)
+ session.commit()
+
+ response = client.get(
+ "/execution/task-instances/count",
+ params={"dag_id": "test_dag", "task_group_id": "group1"},
+ )
+ assert response.status_code == 200
+ assert response.json() == 2
+
+ def test_get_count_task_group_not_found(self, client, session, dag_maker):
+ with dag_maker(dag_id="test_get_count_task_group_not_found", serialized=True):
+ with TaskGroup("group1"):
+ EmptyOperator(task_id="task1")
+ dag_maker.create_dagrun(session=session)
+
+ response = client.get(
+ "/execution/task-instances/count",
+ params={"dag_id": "test_get_count_task_group_not_found", "task_group_id": "non_existent_group"},
+ )
+ assert response.status_code == 404
+ assert response.json()["detail"] == {
+ "reason": "not_found",
+ "message": "Task group non_existent_group not found in DAG test_get_count_task_group_not_found",
+ }
+
+ def test_get_count_dag_not_found(self, client, session):
+ response = client.get(
+ "/execution/task-instances/count",
+ params={"dag_id": "non_existent_dag", "task_group_id": "group1"},
+ )
+ assert response.status_code == 404
+ assert response.json()["detail"] == {
+ "reason": "not_found",
+ "message": "DAG non_existent_dag not found",
+ }
+
+ def test_get_count_with_none_state(self, client, session, create_task_instance):
+ create_task_instance(task_id="task1", dag_id="get_count_with_none", state=None)
+ session.commit()
+
+ response = client.get(
+ "/execution/task-instances/count",
+ params={"dag_id": "get_count_with_none", "states": ["null"]},
+ )
+ assert response.status_code == 200
+ assert response.json() == 1
+
+ def test_get_count_with_mixed_states(self, client, session, create_task_instance):
+ create_task_instance(task_id="task1", state=State.SUCCESS, run_id="runid1", dag_id="mixed_states")
+ create_task_instance(task_id="task2", state=None, run_id="runid2", dag_id="mixed_states")
+ session.commit()
+
+ response = client.get(
+ "/execution/task-instances/count",
+ params={"dag_id": "mixed_states", "states": [State.SUCCESS, "null"]},
+ )
+ assert response.status_code == 200
+ assert response.json() == 2
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index 3cc66ac9f96..b92c27141d7 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -2318,3 +2318,14 @@ def run_task(create_runtime_ti, mock_supervisor_comms, spy_agency) -> RunTaskCal
def mock_xcom_backend():
with mock.patch("airflow.sdk.execution_time.task_runner.XCom", create=True) as xcom_backend:
yield xcom_backend
+
+
[email protected]
+def testing_dag_bundle():
+ from airflow.models.dagbundle import DagBundleModel
+ from airflow.utils.session import create_session
+
+ with create_session() as session:
+ if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0:
+ testing = DagBundleModel(name="testing")
+ session.add(testing)
diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py
index 0a34cc5d48f..dd5ad6c4e7a 100644
--- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py
+++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py
@@ -25,27 +25,32 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException
-from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DagModel
from airflow.models.dagbag import DagBag
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.external_task import WorkflowTrigger
from airflow.providers.standard.utils.sensor_helper import _get_count, _get_external_task_group_task_ids
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
-from airflow.sensors.base import BaseSensorOperator
from airflow.utils.file import correct_maybe_zipped
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State, TaskInstanceState
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk.bases.sensor import BaseSensorOperator
+else:
+ from airflow.sensors.base import BaseSensorOperator
+
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.models.taskinstancekey import TaskInstanceKey
try:
+ from airflow.sdk import BaseOperator
from airflow.sdk.definitions.context import Context
except ImportError:
# TODO: Remove once provider drops support for Airflow 2
+ from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context
@@ -65,15 +70,16 @@ class ExternalDagLink(BaseOperatorLink):
name = "External DAG"
def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
- from airflow.models.renderedtifields import RenderedTaskInstanceFields
-
if TYPE_CHECKING:
assert isinstance(operator, (ExternalTaskMarker, ExternalTaskSensor))
- if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key):
- external_dag_id: str = template_fields.get("external_dag_id", operator.external_dag_id)
- else:
- external_dag_id = operator.external_dag_id
+ external_dag_id = operator.external_dag_id
+
+ if not AIRFLOW_V_3_0_PLUS:
+ from airflow.models.renderedtifields import RenderedTaskInstanceFields
+
+ if template_fields := RenderedTaskInstanceFields.get_templated_fields(ti_key):
+ external_dag_id: str = template_fields.get("external_dag_id", operator.external_dag_id) # type: ignore[no-redef]
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.helpers import build_airflow_dagrun_url
@@ -86,9 +92,7 @@ class ExternalDagLink(BaseOperatorLink):
return build_airflow_url_with_query(query)
-# TODO: Remove BaseOperator from inheritance in https://github.com/apache/airflow/issues/47447
-# It is only temporary until we refactor the code to not directly go to the DB.
-class ExternalTaskSensor(BaseSensorOperator, BaseOperator):
+class ExternalTaskSensor(BaseSensorOperator):
"""
Waits for a different DAG, task group, or task to complete for a specific logical date.
@@ -247,16 +251,22 @@ class ExternalTaskSensor(BaseSensorOperator, BaseOperator):
self.poll_interval = poll_interval
def _get_dttm_filter(self, context):
+ logical_date = context.get("logical_date")
+ if logical_date is None:
+ dag_run = context.get("dag_run")
+ if TYPE_CHECKING:
+ assert dag_run
+
+ logical_date = dag_run.run_after
if self.execution_delta:
- dttm = context["logical_date"] - self.execution_delta
+ dttm = logical_date - self.execution_delta
elif self.execution_date_fn:
dttm = self._handle_execution_date_fn(context=context)
else:
- dttm = context["logical_date"]
+ dttm = logical_date
return dttm if isinstance(dttm, list) else [dttm]
- @provide_session
- def poke(self, context: Context, session: Session = NEW_SESSION) -> bool:
+ def poke(self, context: Context) -> bool:
# delay check to poke rather than __init__ in case it was supplied as XComArgs
if self.external_task_ids and len(self.external_task_ids) > len(set(self.external_task_ids)):
raise ValueError("Duplicate task_ids passed in external_task_ids parameter")
@@ -287,15 +297,62 @@ class ExternalTaskSensor(BaseSensorOperator, BaseOperator):
serialized_dttm_filter,
)
- # In poke mode this will check dag existence only once
- if self.check_existence and not self._has_checked_existence:
- self._check_for_existence(session=session)
+ if AIRFLOW_V_3_0_PLUS:
+ return self._poke_af3(context, dttm_filter)
+ else:
+ return self._poke_af2(dttm_filter)
+
+ def _poke_af3(self, context: Context, dttm_filter: list[datetime.datetime]) -> bool:
+ self._has_checked_existence = True
+ ti = context["ti"]
+
+ def _get_count(states: list[str]) -> int:
+ if self.external_task_ids:
+ return ti.get_ti_count(
+ dag_id=self.external_dag_id,
+ task_ids=self.external_task_ids, # type: ignore[arg-type]
+ logical_dates=dttm_filter,
+ states=states,
+ )
+ elif self.external_task_group_id:
+ return ti.get_ti_count(
+ dag_id=self.external_dag_id,
+ task_group_id=self.external_task_group_id,
+ logical_dates=dttm_filter,
+ states=states,
+ )
+ else:
+ return ti.get_dr_count(
+ dag_id=self.external_dag_id,
+ logical_dates=dttm_filter,
+ states=states,
+ )
- count_failed = -1
if self.failed_states:
- count_failed = self.get_count(dttm_filter, session, self.failed_states)
+ count = _get_count(self.failed_states)
+ count_failed = self._calculate_count(count, dttm_filter)
+ self._handle_failed_states(count_failed)
- # Fail if anything in the list has failed.
+ if self.skipped_states:
+ count = _get_count(self.skipped_states)
+ count_skipped = self._calculate_count(count, dttm_filter)
+ self._handle_skipped_states(count_skipped)
+
+ count = _get_count(self.allowed_states)
+ count_allowed = self._calculate_count(count, dttm_filter)
+ return count_allowed == len(dttm_filter)
+
+ def _calculate_count(self, count: int, dttm_filter: list[datetime.datetime]) -> float | int:
+ """Calculate the normalized count based on the type of check."""
+ if self.external_task_ids:
+ return count / len(self.external_task_ids)
+ elif self.external_task_group_id:
+ return count / len(dttm_filter)
+ else:
+ return count
+
+ def _handle_failed_states(self, count_failed: float | int) -> None:
+ """Handle failed states and raise appropriate exceptions."""
if count_failed > 0:
if self.external_task_ids:
if self.soft_fail:
@@ -317,7 +374,6 @@ class ExternalTaskSensor(BaseSensorOperator, BaseOperator):
f"The external task_group '{self.external_task_group_id}' "
f"in DAG '{self.external_dag_id}' failed."
)
-
else:
if self.soft_fail:
raise AirflowSkipException(
@@ -325,12 +381,8 @@ class ExternalTaskSensor(BaseSensorOperator, BaseOperator):
)
raise AirflowException(f"The external DAG {self.external_dag_id} failed.")
- count_skipped = -1
- if self.skipped_states:
- count_skipped = self.get_count(dttm_filter, session, self.skipped_states)
-
- # Skip if anything in the list has skipped. Note if we are checking multiple tasks and one skips
- # before another errors, we'll skip first.
+ def _handle_skipped_states(self, count_skipped: float | int) -> None:
+ """Handle skipped states and raise appropriate exceptions."""
if count_skipped > 0:
if self.external_task_ids:
raise AirflowSkipException(
@@ -348,7 +400,19 @@ class ExternalTaskSensor(BaseSensorOperator, BaseOperator):
"Skipping."
)
- # only go green if every single task has reached an allowed state
+ @provide_session
+ def _poke_af2(self, dttm_filter: list[datetime.datetime], session: Session = NEW_SESSION) -> bool:
+ if self.check_existence and not self._has_checked_existence:
+ self._check_for_existence(session=session)
+
+ if self.failed_states:
+ count_failed = self.get_count(dttm_filter, session, self.failed_states)
+ self._handle_failed_states(count_failed)
+
+ if self.skipped_states:
+ count_skipped = self.get_count(dttm_filter, session, self.skipped_states)
+ self._handle_skipped_states(count_skipped)
+
count_allowed = self.get_count(dttm_filter, session, self.allowed_states)
return count_allowed == len(dttm_filter)
@@ -483,6 +547,9 @@ class ExternalTaskMarker(EmptyOperator):
"""
template_fields = ["external_dag_id", "external_task_id", "logical_date"]
+ if not AIRFLOW_V_3_0_PLUS:
+ template_fields.append("execution_date")
+
ui_color = "#4db7db"
operator_extra_links = [ExternalDagLink()]
@@ -510,6 +577,9 @@ class ExternalTaskMarker(EmptyOperator):
f"Expected str or datetime.datetime type for logical_date. Got {type(logical_date)}"
)
+ if not AIRFLOW_V_3_0_PLUS:
+ self.execution_date = self.logical_date
+
if recursion_depth <= 0:
raise ValueError("recursion_depth should be a positive integer")
self.recursion_depth = recursion_depth
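Why _calculate_count normalizes: get_ti_count returns one row per matching task instance, so the raw count scales with both the number of watched tasks and the number of logical dates. A worked example with hypothetical numbers:

    # Watching 2 external task ids across 3 logical dates, all succeeded:
    #   raw count             = 2 * 3 = 6
    #   count / len(task_ids) = 6 / 2 = 3
    #   3 == len(dttm_filter) -> poke() returns True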
diff --git a/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py b/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py
index ae5d3c12985..17d54e371bc 100644
--- a/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py
+++ b/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, cast
from sqlalchemy import func, select, tuple_
from airflow.models import DagBag, DagRun, TaskInstance
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils.session import NEW_SESSION, provide_session
if TYPE_CHECKING:
@@ -88,8 +89,10 @@ def _count_stmt(model, states, dttm_filter, external_dag_id) -> Executable:
:param dttm_filter: date time filter for logical date
:param external_dag_id: The ID of the external DAG.
"""
+ date_field = model.logical_date if AIRFLOW_V_3_0_PLUS else model.execution_date
+
return select(func.count()).where(
- model.dag_id == external_dag_id, model.state.in_(states), model.logical_date.in_(dttm_filter)
+ model.dag_id == external_dag_id, model.state.in_(states), date_field.in_(dttm_filter)
)
@@ -106,11 +109,13 @@ def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, exter
task_group = refreshed_dag_info.task_group_dict.get(external_task_group_id)
if task_group:
+ date_field = TaskInstance.logical_date if AIRFLOW_V_3_0_PLUS else TaskInstance.execution_date
+
group_tasks = session.scalars(
select(TaskInstance).filter(
TaskInstance.dag_id == external_dag_id,
TaskInstance.task_id.in_(task.task_id for task in task_group),
- TaskInstance.logical_date.in_(dttm_filter),
+ date_field.in_(dttm_filter),
)
)
diff --git a/airflow-core/tests/unit/sensors/test_external_task_sensor.py b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
similarity index 79%
rename from airflow-core/tests/unit/sensors/test_external_task_sensor.py
rename to providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
index 1a7938cc27d..95deddd4ade 100644
--- a/airflow-core/tests/unit/sensors/test_external_task_sensor.py
+++ b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
@@ -19,10 +19,7 @@ from __future__ import annotations
import itertools
import logging
-import os
import re
-import tempfile
-import zipfile
from datetime import time, timedelta
from unittest import mock
@@ -47,8 +44,7 @@ from airflow.providers.standard.sensors.time import TimeSensor
from airflow.providers.standard.triggers.external_task import WorkflowTrigger
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.timetables.base import DataInterval
-from airflow.utils.hashlib_wrapper import md5
-from airflow.utils.session import NEW_SESSION, create_session, provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.timezone import coerce_datetime, datetime
@@ -57,7 +53,6 @@ from airflow.utils.types import DagRunType
from tests_common.test_utils.db import clear_db_runs
from tests_common.test_utils.mock_operators import MockOperator
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-from unit.models import TEST_DAGS_FOLDER
if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType
@@ -81,42 +76,13 @@ def clean_db():
clear_db_runs()
[email protected]
-def dag_zip_maker(testing_dag_bundle):
- class DagZipMaker:
- def __call__(self, *dag_files):
- self.__dag_files = [os.sep.join([TEST_DAGS_FOLDER.__str__(), dag_file]) for dag_file in dag_files]
- dag_files_hash = md5("".join(self.__dag_files).encode()).hexdigest()
- self.__tmp_dir = os.sep.join([tempfile.tempdir, dag_files_hash])
-
- self.__zip_file_name = os.sep.join([self.__tmp_dir, f"{dag_files_hash}.zip"])
-
- if not os.path.exists(self.__tmp_dir):
- os.mkdir(self.__tmp_dir)
- return self
-
- def __enter__(self):
- with zipfile.ZipFile(self.__zip_file_name, "x") as zf:
- for dag_file in self.__dag_files:
- zf.write(dag_file, os.path.basename(dag_file))
- dagbag = DagBag(dag_folder=self.__tmp_dir, include_examples=False)
- dagbag.sync_to_db("testing", None)
- return dagbag
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- os.unlink(self.__zip_file_name)
- os.rmdir(self.__tmp_dir)
-
- return DagZipMaker()
-
-
[email protected]("testing_dag_bundle")
-class TestExternalTaskSensor:
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for v3.0+")
+class TestExternalTaskSensorV2:
def setup_method(self):
self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
self.args = {"owner": "airflow", "start_date": DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, schedule=None, default_args=self.args)
- self.dag_run_id = DagRunType.MANUAL.generate_run_id(suffix=DEFAULT_DATE.isoformat())
+ self.dag_run_id = DagRunType.MANUAL.generate_run_id(DEFAULT_DATE)
def add_time_sensor(self, task_id=TEST_TASK_ID):
# TODO: Remove BaseOperator in https://github.com/apache/airflow/issues/47447
@@ -133,7 +99,7 @@ class TestExternalTaskSensor:
with TaskGroup(group_id=TEST_TASK_GROUP_ID) as task_group:
_ = [EmptyOperator(task_id=f"task{i}") for i in range(len(target_states))]
dag.sync_to_db()
- SerializedDagModel.write_dag(dag, bundle_name="test_bundle")
+ SerializedDagModel.write_dag(dag)
for idx, task in enumerate(task_group):
ti = TaskInstance(task=task, run_id=self.dag_run_id)
@@ -156,7 +122,7 @@ class TestExternalTaskSensor:
fake_task()
fake_mapped_task.expand(x=list(map_indexes))
dag.sync_to_db()
- SerializedDagModel.write_dag(dag, bundle_name="test_bundle")
+ SerializedDagModel.write_dag(dag)
for task in task_group:
if task.task_id == "fake_mapped_task":
@@ -530,7 +496,7 @@ exit 0
.filter(
TI.dag_id == dag_external_id,
TI.state == State.FAILED,
- TI.logical_date == DEFAULT_DATE + timedelta(seconds=1),
+ TI.execution_date == DEFAULT_DATE + timedelta(seconds=1),
)
.all()
)
@@ -977,10 +943,301 @@ exit 0
check_existence=True,
**kwargs,
)
+ if not hasattr(op, "never_fail"):
+ expected_message = "Skipping due to soft_fail is set to True." if soft_fail else expected_message
with pytest.raises(expected_exception, match=expected_message):
op.execute(context={})
[email protected](not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
[email protected]("testing_dag_bundle")
+class TestExternalTaskSensorV3:
+ def setup_method(self):
+ # Create a mock for TaskInstance with get_ti_count method
+ mock_ti = mock.MagicMock()
+ mock_ti.get_ti_count = mock.MagicMock(return_value=0) # Default return value
+
+ self.context = {
+ "execution_date": DEFAULT_DATE,
+ "logical_date": DEFAULT_DATE,
+ "ti": mock_ti,
+ "task": mock.MagicMock(),
+ "run_id": "test_run_id",
+ }
+
+ @pytest.mark.execution_timeout(10)
+ def test_external_task_sensor_success(self, dag_maker):
+ """Test that the sensor succeeds when the external task succeeds."""
+ with dag_maker("test_dag_child"):
+ op = ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag_parent",
+ external_task_id="test_external_task_sensor_success",
+ allowed_states=["success"],
+ )
+
+ # Mimic DB response to get_ti_count as 1
+ self.context["ti"].get_ti_count.return_value = 1
+
+ op.execute(context=self.context)
+
+ self.context["ti"].get_ti_count.assert_called_once_with(
+ dag_id="test_dag_parent",
+ logical_dates=[DEFAULT_DATE],
+ states=["success"],
+ task_ids=["test_external_task_sensor_success"],
+ )
+
+ @pytest.mark.execution_timeout(10)
+ def test_external_task_sensor_failure(self, dag_maker):
+ """Test that the sensor fails when the external task fails."""
+ with dag_maker("test_dag_child"):
+ op = ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag_parent",
+ external_task_id="test_external_task_sensor_failure",
+ failed_states=[State.FAILED],
+ )
+
+ self.context["ti"].get_ti_count.return_value = 1
+
+ with pytest.raises(AirflowException):
+ op.execute(context=self.context)
+
+ self.context["ti"].get_ti_count.assert_called_once_with(
+ dag_id="test_dag_parent",
+ logical_dates=[DEFAULT_DATE],
+ states=[State.FAILED],
+ task_ids=["test_external_task_sensor_failure"],
+ )
+
+ @pytest.mark.execution_timeout(10)
+ def test_external_task_sensor_soft_fail(self, dag_maker):
+ """Test that the sensor skips when soft_fail is True and external task fails."""
+ with dag_maker("test_dag_child"):
+ op = ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag_parent",
+ external_task_id="test_external_task_sensor_soft_fail",
+ failed_states=[State.FAILED],
+ soft_fail=True,
+ )
+
+ self.context["ti"].get_ti_count.return_value = 1
+
+ with pytest.raises(AirflowSkipException):
+ op.execute(context=self.context)
+
+ self.context["ti"].get_ti_count.assert_called_once_with(
+ dag_id="test_dag_parent",
+ logical_dates=[DEFAULT_DATE],
+ states=[State.FAILED],
+ task_ids=["test_external_task_sensor_soft_fail"],
+ )
+
+ @pytest.mark.execution_timeout(10)
+ def test_external_task_sensor_multiple_task_ids(self, dag_maker):
+ with dag_maker("test_dag_child"):
+ op = ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag_parent",
+ external_task_ids=["task1", "task2"],
+ allowed_states=["success"],
+ )
+
+ self.context["ti"].get_ti_count.return_value = 2
+ op.execute(context=self.context)
+
+ self.context["ti"].get_ti_count.assert_called_once_with(
+ dag_id="test_dag_parent",
+ logical_dates=[DEFAULT_DATE],
+ states=["success"],
+ task_ids=["task1", "task2"],
+ )
+
+ @pytest.mark.execution_timeout(10)
+ def test_external_task_sensor_skipped_states(self, dag_maker):
+ with dag_maker("test_dag_child"):
+ op = ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag_parent",
+ external_task_id="test_external_task_sensor_skipped_states",
+ skipped_states=[State.SKIPPED],
+ )
+
+ self.context["ti"].get_ti_count.return_value = 1
+ with pytest.raises(AirflowSkipException):
+ op.execute(context=self.context)
+
+ self.context["ti"].get_ti_count.assert_called_once_with(
+ dag_id="test_dag_parent",
+ logical_dates=[DEFAULT_DATE],
+ states=[State.SKIPPED],
+ task_ids=["test_external_task_sensor_skipped_states"],
+ )
+
+ def test_external_task_sensor_invalid_combination(self, dag_maker):
+ """Test that the sensor raises an error with invalid parameter combinations."""
+ with pytest.raises(ValueError):
+ with dag_maker("test_external_task_sensor_invalid_combination"):
+ ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag",
+ external_task_id="test_task",
+ external_task_ids=["test_task"],
+ )
+
+ def test_external_task_sensor_invalid_state(self, dag_maker):
+ with pytest.raises(ValueError):
+ with dag_maker("test_external_task_sensor_invalid_state"):
+ ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag",
+ external_task_id="test_task",
+ allowed_states=["invalid_state"],
+ )
+
+ @pytest.mark.execution_timeout(10)
+ def test_external_task_sensor_task_group(self, dag_maker):
+ with dag_maker("test_dag_child"):
+ op = ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag_parent",
+ external_task_group_id="test_group",
+ allowed_states=["success"],
+ )
+
+ self.context["ti"].get_ti_count.return_value = 1
+ op.execute(context=self.context)
+
+ self.context["ti"].get_ti_count.assert_called_once_with(
+ dag_id="test_dag_parent",
+ logical_dates=[DEFAULT_DATE],
+ states=["success"],
+ task_group_id="test_group",
+ )
+
+ @pytest.mark.execution_timeout(10)
+ def test_external_task_sensor_execution_date_fn(self, dag_maker):
+ def execution_date_fn(dt):
+ return [dt + timedelta(hours=1)]
+
+ with dag_maker("test_dag_child"):
+ op = ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag_parent",
+ external_task_id="test_task",
+ execution_date_fn=execution_date_fn,
+ allowed_states=["success"],
+ )
+
+ self.context["ti"].get_ti_count.return_value = 1
+ op.execute(context=self.context)
+
+ expected_date = DEFAULT_DATE + timedelta(hours=1)
+ self.context["ti"].get_ti_count.assert_called_once_with(
+ dag_id="test_dag_parent",
+ logical_dates=[expected_date],
+ states=["success"],
+ task_ids=["test_task"],
+ )
+
+ @pytest.mark.execution_timeout(10)
+ def test_external_task_sensor_execution_delta(self, dag_maker):
+ with dag_maker("test_dag_child"):
+ op = ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag_parent",
+ external_task_id="test_task",
+ execution_delta=timedelta(hours=1),
+ allowed_states=["success"],
+ )
+
+ self.context["ti"].get_ti_count.return_value = 1
+ op.execute(context=self.context)
+
+ expected_date = DEFAULT_DATE - timedelta(hours=1)
+ self.context["ti"].get_ti_count.assert_called_once_with(
+ dag_id="test_dag_parent",
+ logical_dates=[expected_date],
+ states=["success"],
+ task_ids=["test_task"],
+ )
+
+ @pytest.mark.execution_timeout(10)
+ def test_external_task_sensor_duplicate_task_ids(self, dag_maker):
+ with dag_maker("test_dag_child"):
+ op = ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag_parent",
+ external_task_ids=["task1", "task1"],
+ allowed_states=["success"],
+ )
+
+ with pytest.raises(ValueError, match="Duplicate task_ids passed in external_task_ids parameter"):
+ op.execute(context=self.context)
+
+ @pytest.mark.execution_timeout(10)
+ def test_external_task_sensor_deferrable(self, dag_maker):
+ with dag_maker("test_dag_child"):
+ op = ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag_parent",
+ external_task_id="test_task",
+ deferrable=True,
+ allowed_states=["success"],
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ op.execute(context=self.context)
+
+ assert isinstance(exc.value.trigger, WorkflowTrigger)
+ assert exc.value.trigger.external_dag_id == "test_dag_parent"
+ assert exc.value.trigger.external_task_ids == ["test_task"]
+
+ @pytest.mark.execution_timeout(10)
+ def test_external_task_sensor_only_dag_id(self, dag_maker):
+ """Test that the sensor works correctly when only external_dag_id is provided."""
+ with dag_maker("test_dag_child"):
+ op = ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag_parent",
+ allowed_states=["success"],
+ )
+
+ self.context["ti"].get_dr_count = mock.MagicMock(return_value=1)
+
+ op.execute(context=self.context)
+
+ self.context["ti"].get_dr_count.assert_called_once_with(
+ dag_id="test_dag_parent",
+ logical_dates=[DEFAULT_DATE],
+ states=["success"],
+ )
+
+ @pytest.mark.execution_timeout(10)
+ def test_external_task_sensor_task_group_failed_states(self, dag_maker):
+ with dag_maker("test_dag_child"):
+ op = ExternalTaskSensor(
+ task_id="test_external_task_sensor_check",
+ external_dag_id="test_dag_parent",
+ external_task_group_id="test_group",
+ failed_states=[State.FAILED],
+ )
+
+ self.context["ti"].get_ti_count.return_value = 1
+
+ with pytest.raises(AirflowException):
+ op.execute(context=self.context)
+
+ self.context["ti"].get_ti_count.assert_called_once_with(
+ dag_id="test_dag_parent",
+ logical_dates=[DEFAULT_DATE],
+ states=[State.FAILED],
+ task_group_id="test_group",
+ )
+
+
class TestExternalTaskAsyncSensor:
TASK_ID = "external_task_sensor_check"
EXTERNAL_DAG_ID = "child_dag" # DAG the external task sensor is waiting on
@@ -1050,14 +1307,7 @@ class TestExternalTaskAsyncSensor:
mock_log_info.assert_called_with("External tasks %s has executed successfully.", [EXTERNAL_TASK_ID])
-def test_external_task_sensor_check_zipped_dag_existence(dag_zip_maker):
- with dag_zip_maker("test_external_task_sensor_check_existense.py") as dagbag:
- with create_session() as session:
- dag = dagbag.dags["test_external_task_sensor_check_existence"]
- op = dag.tasks[0]
- op._check_for_existence(session)
-
-
[email protected](not AIRFLOW_V_3_0_PLUS, reason="Needs Flask app context
fixture for AF 2")
@pytest.mark.parametrize(
argnames=["external_dag_id", "external_task_id", "expected_external_dag_id", "expected_external_task_id"],
argvalues=[
@@ -1166,7 +1416,10 @@ def dag_bag_ext():
task_a_3 >> task_b_3
for dag in [dag_0, dag_1, dag_2, dag_3]:
- dag_bag.bag_dag(dag=dag)
+ if AIRFLOW_V_3_0_PLUS:
+ dag_bag.bag_dag(dag=dag)
+ else:
+ dag_bag.bag_dag(dag=dag, root_dag=dag)
yield dag_bag
@@ -1215,7 +1468,10 @@ def dag_bag_parent_child():
)
for dag in [dag_0, dag_1]:
- dag_bag.bag_dag(dag=dag)
+ if AIRFLOW_V_3_0_PLUS:
+ dag_bag.bag_dag(dag=dag)
+ else:
+ dag_bag.bag_dag(dag=dag, root_dag=dag)
yield dag_bag
@@ -1237,22 +1493,37 @@ def run_tasks(
for dag in dag_bag.dags.values():
data_interval = DataInterval(coerce_datetime(logical_date), coerce_datetime(logical_date))
- runs[dag.dag_id] = dagrun = dag.create_dagrun(
- run_id=dag.timetable.generate_run_id(
- run_type=DagRunType.MANUAL,
+ if AIRFLOW_V_3_0_PLUS:
+ runs[dag.dag_id] = dagrun = dag.create_dagrun(
+ run_id=dag.timetable.generate_run_id(
+ run_type=DagRunType.MANUAL,
+ run_after=logical_date,
+ data_interval=data_interval,
+ ),
+ logical_date=logical_date,
+ data_interval=data_interval,
run_after=logical_date,
+ run_type=DagRunType.MANUAL,
+ triggered_by=DagRunTriggeredByType.TEST,
+ dag_version=None,
+ state=DagRunState.RUNNING,
+ start_date=logical_date,
+ session=session,
+ )
+ else:
+ runs[dag.dag_id] = dagrun = dag.create_dagrun( # type: ignore[call-arg]
+ run_id=dag.timetable.generate_run_id( # type: ignore[call-arg]
+ run_type=DagRunType.MANUAL,
+ logical_date=logical_date,
+ data_interval=data_interval,
+ ),
+ execution_date=logical_date,
data_interval=data_interval,
- ),
- logical_date=logical_date,
- data_interval=data_interval,
- run_after=logical_date,
- run_type=DagRunType.MANUAL,
- triggered_by=DagRunTriggeredByType.TEST,
- dag_version=None,
- state=DagRunState.RUNNING,
- start_date=logical_date,
- session=session,
- )
+ run_type=DagRunType.MANUAL,
+ state=DagRunState.RUNNING,
+ start_date=logical_date,
+ session=session,
+ )
# we use sorting by task_id here because for the test DAG structure of ours
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
@@ -1290,7 +1561,7 @@ def clear_tasks(
"""
Clear the task and its downstream tasks recursively for the dag in the given dagbag.
"""
- partial: DAG = dag.partial_subset(task_ids=[task.task_id], include_downstream=True)
+ partial: DAG = dag.partial_subset(task_ids_or_regex=[task.task_id], include_downstream=True)
return partial.clear(
start_date=start_date,
end_date=end_date,
@@ -1300,6 +1571,7 @@ def clear_tasks(
)
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+")
def test_external_task_marker_transitive(dag_bag_ext):
"""
Test clearing tasks across DAGs.
@@ -1314,6 +1586,7 @@ def test_external_task_marker_transitive(dag_bag_ext):
assert_ti_state_equal(ti_b_3, State.NONE)
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+")
@provide_session
def test_external_task_marker_clear_activate(dag_bag_parent_child, session):
"""
@@ -1326,22 +1599,9 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child, session):
run_tasks(dag_bag, logical_date=day_1)
run_tasks(dag_bag, logical_date=day_2)
- from sqlalchemy import select
-
- run_ids = []
# Assert that dagruns of all the affected dags are set to SUCCESS before tasks are cleared.
- for dag, logical_date in itertools.product(dag_bag.dags.values(), [day_1, day_2]):
- run_id = (
- select(DagRun.run_id)
- .where(DagRun.logical_date == logical_date)
- .order_by(DagRun.id.desc())
- .limit(1)
- )
- run_ids.append(run_id)
- dagrun = dag.get_dagrun(
- run_id=run_id,
- session=session,
- )
+ for dag, execution_date in itertools.product(dag_bag.dags.values(), [day_1, day_2]):
+ dagrun = dag.get_dagrun(execution_date=execution_date, session=session)
dagrun.set_state(State.SUCCESS)
session.flush()
@@ -1351,10 +1611,10 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child, session):
# Assert that dagruns of all the affected dags are set to QUEUED after tasks are cleared.
# Unaffected dagruns should be left as SUCCESS.
- dagrun_0_1 = dag_bag.get_dag("parent_dag_0").get_dagrun(run_id=run_ids[0], session=session)
- dagrun_0_2 = dag_bag.get_dag("parent_dag_0").get_dagrun(run_id=run_ids[1], session=session)
- dagrun_1_1 = dag_bag.get_dag("child_dag_1").get_dagrun(run_id=run_ids[2], session=session)
- dagrun_1_2 = dag_bag.get_dag("child_dag_1").get_dagrun(run_id=run_ids[3], session=session)
+ dagrun_0_1 = dag_bag.get_dag("parent_dag_0").get_dagrun(execution_date=day_1, session=session)
+ dagrun_0_2 = dag_bag.get_dag("parent_dag_0").get_dagrun(execution_date=day_2, session=session)
+ dagrun_1_1 = dag_bag.get_dag("child_dag_1").get_dagrun(execution_date=day_1, session=session)
+ dagrun_1_2 = dag_bag.get_dag("child_dag_1").get_dagrun(execution_date=day_2, session=session)
assert dagrun_0_1.state == State.QUEUED
assert dagrun_0_2.state == State.QUEUED
@@ -1362,6 +1622,7 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child, session):
assert dagrun_1_2.state == State.SUCCESS
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+")
def test_external_task_marker_future(dag_bag_ext):
"""
Test clearing tasks with no end_date. This is the case when users clear tasks with
@@ -1386,6 +1647,7 @@ def test_external_task_marker_future(dag_bag_ext):
assert_ti_state_equal(ti_b_3_date_1, State.NONE)
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+")
def test_external_task_marker_exception(dag_bag_ext):
"""
Clearing across multiple DAGs should raise AirflowException if more levels are being cleared
@@ -1463,13 +1725,17 @@ def dag_bag_cyclic():
task_a >> task_b
for dag in dags:
- dag_bag.bag_dag(dag=dag)
+ if AIRFLOW_V_3_0_PLUS:
+ dag_bag.bag_dag(dag=dag)
+ else:
+ dag_bag.bag_dag(dag=dag, root_dag=dag) # type: ignore[call-arg]
return dag_bag
return _factory
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+")
def test_external_task_marker_cyclic_deep(dag_bag_cyclic):
"""
Tests clearing across multiple DAGs that have cyclic dependencies.
AirflowException should be
@@ -1483,6 +1749,7 @@ def test_external_task_marker_cyclic_deep(dag_bag_cyclic):
clear_tasks(dag_bag, dag_0, task_a_0)
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+")
def test_external_task_marker_cyclic_shallow(dag_bag_cyclic):
"""
Tests clearing across multiple DAGs that have cyclic dependencies shallower
@@ -1513,8 +1780,13 @@ def dag_bag_multiple():
dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False)
daily_dag = DAG("daily_dag", start_date=DEFAULT_DATE, schedule="@daily")
agg_dag = DAG("agg_dag", start_date=DEFAULT_DATE, schedule="@daily")
- dag_bag.bag_dag(dag=daily_dag)
- dag_bag.bag_dag(dag=agg_dag)
+
+ if AIRFLOW_V_3_0_PLUS:
+ dag_bag.bag_dag(dag=daily_dag)
+ dag_bag.bag_dag(dag=agg_dag)
+ else:
+ dag_bag.bag_dag(dag=daily_dag, root_dag=daily_dag)
+ dag_bag.bag_dag(dag=agg_dag, root_dag=agg_dag)
daily_task = EmptyOperator(task_id="daily_tas", dag=daily_dag)
@@ -1584,7 +1856,10 @@ def dag_bag_head_tail():
)
head >> body >> tail
- dag_bag.bag_dag(dag=dag)
+ if AIRFLOW_V_3_0_PLUS:
+ dag_bag.bag_dag(dag=dag)
+ else:
+ dag_bag.bag_dag(dag=dag, root_dag=dag)
return dag_bag
@@ -1600,10 +1875,13 @@ def test_clear_overlapping_external_task_marker(dag_bag_head_tail, session):
dag_id=dag.dag_id,
start_date=logical_date,
state=DagRunState.SUCCESS,
- logical_date=logical_date,
run_type=DagRunType.MANUAL,
run_id=f"test_{delta}",
)
+ if AIRFLOW_V_3_0_PLUS:
+ dagrun.logical_date = logical_date
+ else:
+ dagrun.execution_date = logical_date
session.add(dagrun)
for task in dag.tasks:
ti = TaskInstance(task=task)
@@ -1625,10 +1903,13 @@ def test_clear_overlapping_external_task_marker_with_end_date(dag_bag_head_tail,
dag_id=dag.dag_id,
start_date=logical_date,
state=DagRunState.SUCCESS,
- logical_date=logical_date,
run_type=DagRunType.MANUAL,
run_id=f"test_{delta}",
)
+ if AIRFLOW_V_3_0_PLUS:
+ dagrun.logical_date = logical_date
+ else:
+ dagrun.execution_date = logical_date
session.add(dagrun)
for task in dag.tasks:
ti = TaskInstance(task=task)
@@ -1689,7 +1970,10 @@ def dag_bag_head_tail_mapped_tasks():
)
head >> body >> tail
- dag_bag.bag_dag(dag=dag)
+ if AIRFLOW_V_3_0_PLUS:
+ dag_bag.bag_dag(dag=dag)
+ else:
+ dag_bag.bag_dag(dag=dag, root_dag=dag)
return dag_bag
@@ -1705,10 +1989,13 @@ def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_m
dag_id=dag.dag_id,
start_date=logical_date,
state=DagRunState.SUCCESS,
- logical_date=logical_date,
run_type=DagRunType.MANUAL,
run_id=f"test_{delta}",
)
+ if AIRFLOW_V_3_0_PLUS:
+ dagrun.logical_date = logical_date
+ else:
+ dagrun.execution_date = logical_date
session.add(dagrun)
for task in dag.tasks:
if task.task_id == "dummy_task":
@@ -1721,12 +2008,19 @@ def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_m
ti.state = TaskInstanceState.SUCCESS
dagrun.task_instances.append(ti)
session.flush()
+ if AIRFLOW_V_3_0_PLUS:
+ dag = dag.partial_subset(
+ task_ids=["head"],
+ include_downstream=True,
+ include_upstream=False,
+ )
+ else:
+ dag = dag.partial_subset(
+ task_ids_or_regex=["head"],
+ include_downstream=True,
+ include_upstream=False,
+ )
- dag = dag.partial_subset(
- task_ids=["head"],
- include_downstream=True,
- include_upstream=False,
- )
task_ids = list(dag.task_dict)
assert (
dag.clear(
diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py
index cfffd2b3823..a2fff3a335b 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -59,10 +59,12 @@ from airflow.sdk.api.datamodels._generated import (
)
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
+ DRCount,
ErrorResponse,
OKResponse,
SkipDownstreamTasks,
TaskRescheduleStartDate,
+ TICount,
)
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser
@@ -200,6 +202,31 @@ class TaskInstanceOperations:
resp = self.client.get(f"task-reschedules/{id}/start_date", params={"try_number": try_number})
return TaskRescheduleStartDate.model_construct(start_date=resp.json())
+ def get_count(
+ self,
+ dag_id: str,
+ task_ids: list[str] | None = None,
+ task_group_id: str | None = None,
+ logical_dates: list[datetime] | None = None,
+ run_ids: list[str] | None = None,
+ states: list[str] | None = None,
+ ) -> TICount:
+ """Get count of task instances matching the given criteria."""
+ params = {
+ "dag_id": dag_id,
+ "task_ids": task_ids,
+ "task_group_id": task_group_id,
+ "logical_dates": [d.isoformat() for d in logical_dates] if
logical_dates is not None else None,
+ "run_ids": run_ids,
+ "states": states,
+ }
+
+ # Remove None values from params
+ params = {k: v for k, v in params.items() if v is not None}
+
+ resp = self.client.get("task-instances/count", params=params)
+ return TICount(count=resp.json())
+
class ConnectionOperations:
__slots__ = ("client",)
@@ -452,6 +479,27 @@ class DagRunOperations:
resp = self.client.get(f"dag-runs/{dag_id}/{run_id}/state")
return DagRunStateResponse.model_validate_json(resp.read())
+ def get_count(
+ self,
+ dag_id: str,
+ logical_dates: list[datetime] | None = None,
+ run_ids: list[str] | None = None,
+ states: list[str] | None = None,
+ ) -> DRCount:
+ """Get count of DAG runs matching the given criteria."""
+ params = {
+ "dag_id": dag_id,
+ "logical_dates": [d.isoformat() for d in logical_dates] if
logical_dates is not None else None,
+ "run_ids": run_ids,
+ "states": states,
+ }
+
+ # Remove None values from params
+ params = {k: v for k, v in params.items() if v is not None}
+
+ resp = self.client.get("dag-runs/count", params=params)
+ return DRCount(count=resp.json())
+
class BearerAuth(httpx.Auth):
def __init__(self, token: str):
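
(Usage sketch for the two new client methods; the dag_id, task ids, and dates below are made up, and `client` is assumed to be an already-configured task-sdk Client:)

    from airflow.utils import timezone

    # Both methods return a small model; .count carries the integer.
    ti_count = client.task_instances.get_count(
        dag_id="example_dag",
        task_ids=["extract", "load"],
        states=["success"],
    ).count

    dr_count = client.dag_runs.get_count(
        dag_id="example_dag",
        logical_dates=[timezone.datetime(2025, 1, 1)],
    ).count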
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index d579aa8fb48..76c858f3ac1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -222,6 +222,20 @@ class TaskRescheduleStartDate(BaseModel):
type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate"
+class TICount(BaseModel):
+ """Response containing count of Task Instances matching certain filters."""
+
+ count: int
+ type: Literal["TICount"] = "TICount"
+
+
+class DRCount(BaseModel):
+ """Response containing count of DAG Runs matching certain filters."""
+
+ count: int
+ type: Literal["DRCount"] = "DRCount"
+
+
class ErrorResponse(BaseModel):
error: ErrorType = ErrorType.GENERIC_ERROR
detail: dict | None = None
@@ -239,10 +253,12 @@ ToTask = Annotated[
AssetEventsResult,
ConnectionResult,
DagRunStateResult,
+ DRCount,
ErrorResponse,
PrevSuccessfulDagRunResult,
StartupDetails,
TaskRescheduleStartDate,
+ TICount,
VariableResult,
XComResult,
XComCountResponse,
@@ -445,30 +461,50 @@ class GetTaskRescheduleStartDate(BaseModel):
type: Literal["GetTaskRescheduleStartDate"] = "GetTaskRescheduleStartDate"
+class GetTICount(BaseModel):
+ dag_id: str
+ task_ids: list[str] | None = None
+ task_group_id: str | None = None
+ logical_dates: list[AwareDatetime] | None = None
+ run_ids: list[str] | None = None
+ states: list[str] | None = None
+ type: Literal["GetTICount"] = "GetTICount"
+
+
+class GetDRCount(BaseModel):
+ dag_id: str
+ logical_dates: list[AwareDatetime] | None = None
+ run_ids: list[str] | None = None
+ states: list[str] | None = None
+ type: Literal["GetDRCount"] = "GetDRCount"
+
+
ToSupervisor = Annotated[
Union[
- SucceedTask,
DeferTask,
+ DeleteXCom,
GetAssetByName,
GetAssetByUri,
GetAssetEventByAsset,
GetAssetEventByAssetAlias,
GetConnection,
GetDagRunState,
+ GetDRCount,
GetPrevSuccessfulDagRun,
GetTaskRescheduleStartDate,
+ GetTICount,
GetVariable,
GetXCom,
GetXComCount,
PutVariable,
RescheduleTask,
RetryTask,
- SkipDownstreamTasks,
SetRenderedFields,
SetXCom,
+ SkipDownstreamTasks,
+ SucceedTask,
TaskState,
TriggerDagRun,
- DeleteXCom,
],
Field(discriminator="type"),
]
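
(A quick round-trip illustration of the new discriminated messages, with hypothetical values:)

    from airflow.sdk.execution_time.comms import GetTICount, TICount

    msg = GetTICount(dag_id="example_dag", states=["success"])
    # The "type" literal is the discriminator used by ToSupervisor/ToTask.
    assert msg.model_dump()["type"] == "GetTICount"

    resp = TICount.model_validate_json('{"count": 3, "type": "TICount"}')
    assert resp.count == 3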
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index a353ffce23c..21ae3cd4cae 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -77,8 +77,10 @@ from airflow.sdk.execution_time.comms import (
GetAssetEventByAssetAlias,
GetConnection,
GetDagRunState,
+ GetDRCount,
GetPrevSuccessfulDagRun,
GetTaskRescheduleStartDate,
+ GetTICount,
GetVariable,
GetXCom,
GetXComCount,
@@ -988,6 +990,24 @@ class ActivitySubprocess(WatchedSubprocess):
elif isinstance(msg, GetTaskRescheduleStartDate):
tr_resp = self.client.task_instances.get_reschedule_start_date(msg.ti_id, msg.try_number)
resp = tr_resp.model_dump_json().encode()
+ elif isinstance(msg, GetTICount):
+ ti_count = self.client.task_instances.get_count(
+ dag_id=msg.dag_id,
+ task_ids=msg.task_ids,
+ task_group_id=msg.task_group_id,
+ logical_dates=msg.logical_dates,
+ run_ids=msg.run_ids,
+ states=msg.states,
+ )
+ resp = ti_count.model_dump_json().encode()
+ elif isinstance(msg, GetDRCount):
+ dr_count = self.client.dag_runs.get_count(
+ dag_id=msg.dag_id,
+ logical_dates=msg.logical_dates,
+ run_ids=msg.run_ids,
+ states=msg.states,
+ )
+ resp = dr_count.model_dump_json().encode()
else:
log.error("Unhandled request", msg=msg)
return
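
(For reference, the exchange these branches enable is a pair of newline-delimited JSON frames over the supervisor channel; values and field order below are illustrative only:)

    task process -> supervisor:
        {"dag_id": "example_dag", "task_ids": ["final"], "task_group_id": null,
         "logical_dates": null, "run_ids": null, "states": ["success"],
         "type": "GetTICount"}
    supervisor -> task process:
        {"count": 3, "type": "TICount"}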
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index d5d738f5254..d89fbeb0aa9 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -58,9 +58,12 @@ from airflow.sdk.execution_time.callback_runner import create_executable_runner
from airflow.sdk.execution_time.comms import (
DagRunStateResult,
DeferTask,
+ DRCount,
ErrorResponse,
GetDagRunState,
+ GetDRCount,
GetTaskRescheduleStartDate,
+ GetTICount,
RescheduleTask,
RetryTask,
SetRenderedFields,
@@ -69,6 +72,7 @@ from airflow.sdk.execution_time.comms import (
SucceedTask,
TaskRescheduleStartDate,
TaskState,
+ TICount,
ToSupervisor,
ToTask,
TriggerDagRun,
@@ -400,6 +404,62 @@ class RuntimeTaskInstance(TaskInstance):
return response.start_date
+ @staticmethod
+ def get_ti_count(
+ dag_id: str,
+ task_ids: list[str] | None = None,
+ task_group_id: str | None = None,
+ logical_dates: list[datetime] | None = None,
+ run_ids: list[str] | None = None,
+ states: list[str] | None = None,
+ ) -> int:
+ """Return the number of task instances matching the given criteria."""
+ log = structlog.get_logger(logger_name="task")
+
+ SUPERVISOR_COMMS.send_request(
+ log=log,
+ msg=GetTICount(
+ dag_id=dag_id,
+ task_ids=task_ids,
+ task_group_id=task_group_id,
+ logical_dates=logical_dates,
+ run_ids=run_ids,
+ states=states,
+ ),
+ )
+ response = SUPERVISOR_COMMS.get_message()
+
+ if TYPE_CHECKING:
+ assert isinstance(response, TICount)
+
+ return response.count
+
+ @staticmethod
+ def get_dr_count(
+ dag_id: str,
+ logical_dates: list[datetime] | None = None,
+ run_ids: list[str] | None = None,
+ states: list[str] | None = None,
+ ) -> int:
+ """Return the number of DAG runs matching the given criteria."""
+ log = structlog.get_logger(logger_name="task")
+
+ SUPERVISOR_COMMS.send_request(
+ log=log,
+ msg=GetDRCount(
+ dag_id=dag_id,
+ logical_dates=logical_dates,
+ run_ids=run_ids,
+ states=states,
+ ),
+ )
+ response = SUPERVISOR_COMMS.get_message()
+
+ if TYPE_CHECKING:
+ assert isinstance(response, DRCount)
+
+ return response.count
+
def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
"""Push a XCom through XCom.set, which pushes to XCom Backend if configured."""
diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py
index 6760ea51959..d9544589cf9 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -87,6 +87,24 @@ class RuntimeTaskInstanceProtocol(Protocol):
def get_first_reschedule_date(self, first_try_number) -> AwareDatetime | None: ...
+ @staticmethod
+ def get_ti_count(
+ dag_id: str,
+ task_ids: list[str] | None = None,
+ task_group_id: str | None = None,
+ logical_dates: list[AwareDatetime] | None = None,
+ run_ids: list[str] | None = None,
+ states: list[str] | None = None,
+ ) -> int: ...
+
+ @staticmethod
+ def get_dr_count(
+ dag_id: str,
+ logical_dates: list[AwareDatetime] | None = None,
+ run_ids: list[str] | None = None,
+ states: list[str] | None = None,
+ ) -> int: ...
+
class OutletEventAccessorProtocol(Protocol, attrs.AttrsInstance):
"""Protocol for managing access to a specific outlet event accessor."""
diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py
index f0092f8aea9..3339a6cee4c 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -405,6 +405,48 @@ class TestTaskInstanceOperations:
assert result == {"ok": True}
+ def test_get_count_basic(self):
+ """Test basic get_count functionality with just dag_id."""
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert request.url.path == "/task-instances/count"
+ assert request.url.params.get("dag_id") == "test_dag"
+ return httpx.Response(200, json=5)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_instances.get_count(dag_id="test_dag")
+ assert result.count == 5
+
+ def test_get_count_with_all_params(self):
+ """Test get_count with all optional parameters."""
+
+ logical_dates_str = ["2024-01-01T00:00:00+00:00",
"2024-01-02T00:00:00+00:00"]
+ logical_dates = [timezone.parse(d) for d in logical_dates_str]
+ task_ids = ["task1", "task2"]
+ states = ["success", "failed"]
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ assert request.url.path == "/task-instances/count"
+ assert request.method == "GET"
+ params = request.url.params
+ assert params["dag_id"] == "test_dag"
+ assert params.get_list("task_ids") == task_ids
+ assert params["task_group_id"] == "group1"
+ assert params.get_list("logical_dates") == logical_dates_str
+ assert params.get_list("run_ids") == []
+ assert params.get_list("states") == states
+ return httpx.Response(200, json=10)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.task_instances.get_count(
+ dag_id="test_dag",
+ task_ids=task_ids,
+ task_group_id="group1",
+ logical_dates=logical_dates,
+ states=states,
+ )
+ assert result.count == 10
+
class TestVariableOperations:
"""
@@ -904,6 +946,58 @@ class TestDagRunOperations:
assert result == DagRunStateResponse(state=DagRunState.RUNNING)
+ def test_get_count_basic(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == "/dag-runs/count":
+ assert request.url.params["dag_id"] == "test_dag"
+ return httpx.Response(status_code=200, json=1)
+ return httpx.Response(status_code=422)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.dag_runs.get_count(dag_id="test_dag")
+ assert result.count == 1
+
+ def test_get_count_with_states(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == "/dag-runs/count":
+ assert request.url.params["dag_id"] == "test_dag"
+ assert request.url.params.get_list("states") == ["success",
"failed"]
+ return httpx.Response(status_code=200, json=2)
+ return httpx.Response(status_code=422)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.dag_runs.get_count(dag_id="test_dag",
states=["success", "failed"])
+ assert result.count == 2
+
+ def test_get_count_with_logical_dates(self):
+ logical_dates = [timezone.datetime(2025, 1, 1), timezone.datetime(2025, 1, 2)]
+ logical_dates_str = [d.isoformat() for d in logical_dates]
+
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == "/dag-runs/count":
+ assert request.url.params["dag_id"] == "test_dag"
+ assert request.url.params.get_list("logical_dates") ==
logical_dates_str
+ return httpx.Response(status_code=200, json=2)
+ return httpx.Response(status_code=422)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.dag_runs.get_count(
+ dag_id="test_dag", logical_dates=[timezone.datetime(2025, 1, 1),
timezone.datetime(2025, 1, 2)]
+ )
+ assert result.count == 2
+
+ def test_get_count_with_run_ids(self):
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == "/dag-runs/count":
+ assert request.url.params["dag_id"] == "test_dag"
+ assert request.url.params.get_list("run_ids") == ["run1",
"run2"]
+ return httpx.Response(status_code=200, json=2)
+ return httpx.Response(status_code=422)
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.dag_runs.get_count(dag_id="test_dag", run_ids=["run1",
"run2"])
+ assert result.count == 2
+
class TestTaskRescheduleOperations:
def test_get_start_date(self):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 64906726219..5cae5232e64 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -60,6 +60,7 @@ from airflow.sdk.execution_time.comms import (
DagRunStateResult,
DeferTask,
DeleteXCom,
+ DRCount,
ErrorResponse,
GetAssetByName,
GetAssetByUri,
@@ -67,8 +68,10 @@ from airflow.sdk.execution_time.comms import (
GetAssetEventByAssetAlias,
GetConnection,
GetDagRunState,
+ GetDRCount,
GetPrevSuccessfulDagRun,
GetTaskRescheduleStartDate,
+ GetTICount,
GetVariable,
GetXCom,
OKResponse,
@@ -80,6 +83,7 @@ from airflow.sdk.execution_time.comms import (
SucceedTask,
TaskRescheduleStartDate,
TaskState,
+ TICount,
TriggerDagRun,
VariableResult,
XComResult,
@@ -1350,6 +1354,36 @@ class TestHandleRequest:
TaskRescheduleStartDate(start_date=timezone.parse("2024-10-31T12:00:00Z")),
id="get_task_reschedule_start_date",
),
+ pytest.param(
+ GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]),
+ b'{"count":2,"type":"TICount"}\n',
+ "task_instances.get_count",
+ (),
+ {
+ "dag_id": "test_dag",
+ "logical_dates": None,
+ "run_ids": None,
+ "states": None,
+ "task_group_id": None,
+ "task_ids": ["task1", "task2"],
+ },
+ TICount(count=2),
+ id="get_ti_count",
+ ),
+ pytest.param(
+ GetDRCount(dag_id="test_dag", states=["success", "failed"]),
+ b'{"count":2,"type":"DRCount"}\n',
+ "dag_runs.get_count",
+ (),
+ {
+ "dag_id": "test_dag",
+ "logical_dates": None,
+ "run_ids": None,
+ "states": ["success", "failed"],
+ },
+ DRCount(count=2),
+ id="get_dr_count",
+ ),
],
)
def test_handle_requests(
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 150b691fc17..f7ac25f6ae8 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -67,9 +67,12 @@ from airflow.sdk.execution_time.comms import (
ConnectionResult,
DagRunStateResult,
DeferTask,
+ DRCount,
ErrorResponse,
GetConnection,
GetDagRunState,
+ GetDRCount,
+ GetTICount,
GetVariable,
GetXCom,
OKResponse,
@@ -81,6 +84,7 @@ from airflow.sdk.execution_time.comms import (
SucceedTask,
TaskRescheduleStartDate,
TaskState,
+ TICount,
TriggerDagRun,
VariableResult,
XComResult,
@@ -1396,6 +1400,54 @@ class TestRuntimeTaskInstance:
context = runtime_ti.get_template_context()
assert runtime_ti.get_first_reschedule_date(context=context) == expected_date
+ def test_get_ti_count(self, mock_supervisor_comms):
+ """Test that get_ti_count sends the correct request and returns the
count."""
+ mock_supervisor_comms.get_message.return_value = TICount(count=2)
+
+ count = RuntimeTaskInstance.get_ti_count(
+ dag_id="test_dag",
+ task_ids=["task1", "task2"],
+ task_group_id="group1",
+ logical_dates=[timezone.datetime(2024, 1, 1)],
+ run_ids=["run1"],
+ states=["success", "failed"],
+ )
+
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ log=mock.ANY,
+ msg=GetTICount(
+ dag_id="test_dag",
+ task_ids=["task1", "task2"],
+ task_group_id="group1",
+ logical_dates=[timezone.datetime(2024, 1, 1)],
+ run_ids=["run1"],
+ states=["success", "failed"],
+ ),
+ )
+ assert count == 2
+
+ def test_get_dr_count(self, mock_supervisor_comms):
+ """Test that get_dr_count sends the correct request and returns the
count."""
+ mock_supervisor_comms.get_message.return_value = DRCount(count=2)
+
+ count = RuntimeTaskInstance.get_dr_count(
+ dag_id="test_dag",
+ logical_dates=[timezone.datetime(2024, 1, 1)],
+ run_ids=["run1"],
+ states=["success", "failed"],
+ )
+
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ log=mock.ANY,
+ msg=GetDRCount(
+ dag_id="test_dag",
+ logical_dates=[timezone.datetime(2024, 1, 1)],
+ run_ids=["run1"],
+ states=["success", "failed"],
+ ),
+ )
+ assert count == 2
+
class TestXComAfterTaskExecution:
@pytest.mark.parametrize(