This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6775bf7bae1 Make `ExternalTaskSensor` work with Task SDK (#48651)
6775bf7bae1 is described below

commit 6775bf7bae13f4291e18d4118179c14e4444de0d
Author: Kaxil Naik <[email protected]>
AuthorDate: Thu Apr 3 13:57:49 2025 +0530

    Make `ExternalTaskSensor` work with Task SDK (#48651)
    
    closes https://github.com/apache/airflow/issues/47447
    
    closes https://github.com/apache/airflow/issues/47948
---
 .../api_fastapi/execution_api/routes/dag_runs.py   |  30 +-
 .../execution_api/routes/task_instances.py         | 108 ++++-
 airflow-core/tests/conftest.py                     |  11 -
 .../execution_api/versions/head/test_dag_runs.py   |  98 ++++-
 .../versions/head/test_task_instances.py           | 165 +++++++
 devel-common/src/tests_common/pytest_plugin.py     |  11 +
 .../providers/standard/sensors/external_task.py    | 128 ++++--
 .../providers/standard/utils/sensor_helper.py      |   9 +-
 .../standard}/sensors/test_external_task_sensor.py | 490 ++++++++++++++++-----
 task-sdk/src/airflow/sdk/api/client.py             |  48 ++
 task-sdk/src/airflow/sdk/execution_time/comms.py   |  42 +-
 .../src/airflow/sdk/execution_time/supervisor.py   |  20 +
 .../src/airflow/sdk/execution_time/task_runner.py  |  60 +++
 task-sdk/src/airflow/sdk/types.py                  |  18 +
 task-sdk/tests/task_sdk/api/test_client.py         |  94 ++++
 .../task_sdk/execution_time/test_supervisor.py     |  34 ++
 .../task_sdk/execution_time/test_task_runner.py    |  52 +++
 17 files changed, 1261 insertions(+), 157 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
index 3a680c1ef8c..a4398601138 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py
@@ -18,13 +18,15 @@
 from __future__ import annotations
 
 import logging
+from typing import Annotated
 
-from fastapi import HTTPException, status
-from sqlalchemy import select
+from fastapi import HTTPException, Query, status
+from sqlalchemy import func, select
 
 from airflow.api.common.trigger_dag import trigger_dag
 from airflow.api_fastapi.common.db.common import SessionDep
 from airflow.api_fastapi.common.router import AirflowRouter
+from airflow.api_fastapi.common.types import UtcDateTime
 from airflow.api_fastapi.execution_api.datamodels.dagrun import 
DagRunStateResponse, TriggerDAGRunPayload
 from airflow.exceptions import DagRunAlreadyExists
 from airflow.models.dag import DagModel
@@ -150,3 +152,27 @@ def get_dagrun_state(
         )
 
     return DagRunStateResponse(state=dag_run.state)
+
+
[email protected]("/count", status_code=status.HTTP_200_OK)
+def get_dr_count(
+    dag_id: str,
+    session: SessionDep,
+    logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None,
+    run_ids: Annotated[list[str] | None, Query()] = None,
+    states: Annotated[list[str] | None, Query()] = None,
+) -> int:
+    """Get the count of DAG runs matching the given criteria."""
+    query = select(func.count()).select_from(DagRun).where(DagRun.dag_id == 
dag_id)
+
+    if logical_dates:
+        query = query.where(DagRun.logical_date.in_(logical_dates))
+
+    if run_ids:
+        query = query.where(DagRun.run_id.in_(run_ids))
+
+    if states:
+        query = query.where(DagRun.state.in_(states))
+
+    count = session.scalar(query)
+    return count or 0
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index ff0ca516314..6dd0732c077 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -23,13 +23,14 @@ from typing import Annotated
 from uuid import UUID
 
 from cadwyn import VersionedAPIRouter
-from fastapi import Body, Depends, HTTPException, status
+from fastapi import Body, Depends, HTTPException, Query, status
 from pydantic import JsonValue
-from sqlalchemy import func, tuple_, update
+from sqlalchemy import func, or_, tuple_, update
 from sqlalchemy.exc import NoResultFound, SQLAlchemyError
 from sqlalchemy.sql import select
 
 from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.common.types import UtcDateTime
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
     PrevSuccessfulDagRunResponse,
     TIDeferredStatePayload,
@@ -45,6 +46,7 @@ from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
     TITerminalStatePayload,
 )
 from airflow.api_fastapi.execution_api.deps import JWTBearer
+from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun as DR
 from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
 from airflow.models.taskreschedule import TaskReschedule
@@ -53,7 +55,9 @@ from airflow.models.xcom import XComModel
 from airflow.utils import timezone
 from airflow.utils.state import DagRunState, TaskInstanceState
 
-router = VersionedAPIRouter(
+router = VersionedAPIRouter()
+
+ti_id_router = VersionedAPIRouter(
     dependencies=[
         # This checks that the UUID in the url matches the one in the token 
for us.
         Depends(JWTBearer(path_param_name="task_instance_id")),
@@ -64,7 +68,7 @@ router = VersionedAPIRouter(
 log = logging.getLogger(__name__)
 
 
[email protected](
+@ti_id_router.patch(
     "/{task_instance_id}/run",
     status_code=status.HTTP_200_OK,
     responses={
@@ -243,7 +247,7 @@ def ti_run(
         )
 
 
[email protected](
+@ti_id_router.patch(
     "/{task_instance_id}/state",
     status_code=status.HTTP_204_NO_CONTENT,
     responses={
@@ -404,7 +408,7 @@ def ti_update_state(
         )
 
 
[email protected](
+@ti_id_router.patch(
     "/{task_instance_id}/skip-downstream",
     status_code=status.HTTP_204_NO_CONTENT,
     responses={
@@ -436,7 +440,7 @@ def ti_skip_downstream(
     log.info("TI %s updated the state of %s task(s) to skipped", ti_id_str, 
result.rowcount)
 
 
[email protected](
+@ti_id_router.put(
     "/{task_instance_id}/heartbeat",
     status_code=status.HTTP_204_NO_CONTENT,
     responses={
@@ -498,7 +502,7 @@ def ti_heartbeat(
     log.debug("Task with %s state heartbeated", previous_state)
 
 
[email protected](
+@ti_id_router.put(
     "/{task_instance_id}/rtif",
     status_code=status.HTTP_201_CREATED,
     # TODO: Add description to the operation
@@ -528,7 +532,7 @@ def ti_put_rtif(
     return {"message": "Rendered task instance fields successfully set"}
 
 
[email protected](
+@ti_id_router.get(
     "/{task_instance_id}/previous-successful-dagrun",
     status_code=status.HTTP_200_OK,
     responses={
@@ -564,8 +568,86 @@ def get_previous_successful_dagrun(
     return PrevSuccessfulDagRunResponse.model_validate(dag_run)
 
 
[email protected]_exists_in_older_versions
[email protected](
[email protected]("/count", status_code=status.HTTP_200_OK)
+def get_count(
+    dag_id: str,
+    session: SessionDep,
+    task_ids: Annotated[list[str] | None, Query()] = None,
+    task_group_id: Annotated[str | None, Query()] = None,
+    logical_dates: Annotated[list[UtcDateTime] | None, Query()] = None,
+    run_ids: Annotated[list[str] | None, Query()] = None,
+    states: Annotated[list[str] | None, Query()] = None,
+) -> int:
+    """Get the count of task instances matching the given criteria."""
+    query = select(func.count()).select_from(TI).where(TI.dag_id == dag_id)
+
+    if task_ids:
+        query = query.where(TI.task_id.in_(task_ids))
+
+    if logical_dates:
+        query = query.where(TI.logical_date.in_(logical_dates))
+
+    if run_ids:
+        query = query.where(TI.run_id.in_(run_ids))
+
+    if task_group_id:
+        # Get all tasks in the task group
+        dag = DagBag(read_dags_from_db=True).get_dag(dag_id, session)
+        if not dag:
+            raise HTTPException(
+                status.HTTP_404_NOT_FOUND,
+                detail={
+                    "reason": "not_found",
+                    "message": f"DAG {dag_id} not found",
+                },
+            )
+
+        task_group = dag.task_group_dict.get(task_group_id)
+        if not task_group:
+            raise HTTPException(
+                status.HTTP_404_NOT_FOUND,
+                detail={
+                    "reason": "not_found",
+                    "message": f"Task group {task_group_id} not found in DAG 
{dag_id}",
+                },
+            )
+
+        # First get all task instances to get the task_id, map_index pairs
+        group_tasks = session.scalars(
+            select(TI).where(
+                TI.dag_id == dag_id,
+                TI.task_id.in_(task.task_id for task in 
task_group.iter_tasks()),
+                *([TI.logical_date.in_(logical_dates)] if logical_dates else 
[]),
+                *([TI.run_id.in_(run_ids)] if run_ids else []),
+            )
+        ).all()
+
+        # Get unique (task_id, map_index) pairs
+        task_map_pairs = [(ti.task_id, ti.map_index) for ti in group_tasks]
+        if not task_map_pairs:
+            # If no task group tasks found, default to checking the task group 
ID itself
+            # This matches the behavior in _get_external_task_group_task_ids
+            task_map_pairs = [(task_group_id, -1)]
+
+        # Update query to use task_id, map_index pairs
+        query = query.where(tuple_(TI.task_id, 
TI.map_index).in_(task_map_pairs))
+
+    if states:
+        if "null" in states:
+            not_none_states = [s for s in states if s != "null"]
+            if not_none_states:
+                query = query.where(or_(TI.state.is_(None), 
TI.state.in_(not_none_states)))
+            else:
+                query = query.where(TI.state.is_(None))
+        else:
+            query = query.where(TI.state.in_(states))
+
+    count = session.scalar(query)
+    return count or 0
+
+
+@ti_id_router.only_exists_in_older_versions
+@ti_id_router.post(
     "/{task_instance_id}/runtime-checks",
     status_code=status.HTTP_204_NO_CONTENT,
     # TODO: Add description to the operation
@@ -602,3 +684,7 @@ def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
     # max_tries is initialised with the retries defined at task level, we do 
not need to explicitly ask for
     # retries from the task SDK now, we can handle using max_tries
     return max_tries != 0 and try_number <= max_tries
+
+
+# This line should be at the end of the file to ensure all routes are 
registered
+router.include_router(ti_id_router)
diff --git a/airflow-core/tests/conftest.py b/airflow-core/tests/conftest.py
index c5affb469f4..c605cc7648e 100644
--- a/airflow-core/tests/conftest.py
+++ b/airflow-core/tests/conftest.py
@@ -78,17 +78,6 @@ def clear_all_logger_handlers():
     remove_all_non_pytest_log_handlers()
 
 
[email protected]
-def testing_dag_bundle():
-    from airflow.models.dagbundle import DagBundleModel
-    from airflow.utils.session import create_session
-
-    with create_session() as session:
-        if session.query(DagBundleModel).filter(DagBundleModel.name == 
"testing").count() == 0:
-            testing = DagBundleModel(name="testing")
-            session.add(testing)
-
-
 @contextmanager
 def _config_bundles(bundles: dict[str, Path | str]):
     from tests_common.test_utils.config import conf_vars
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py
index c5624f188b5..f9f8d489d3d 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dag_runs.py
@@ -23,7 +23,7 @@ from airflow.models import DagModel
 from airflow.models.dagrun import DagRun
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.utils import timezone
-from airflow.utils.state import DagRunState
+from airflow.utils.state import DagRunState, State
 
 from tests_common.test_utils.db import clear_db_runs
 
@@ -218,3 +218,99 @@ class TestDagRunState:
         response = client.post(f"/execution/dag-runs/{dag_id}/{run_id}/clear")
 
         assert response.status_code == 404
+
+
+class TestGetDagRunCount:
+    def setup_method(self):
+        clear_db_runs()
+
+    def teardown_method(self):
+        clear_db_runs()
+
+    def test_get_count_basic(self, client, session, dag_maker):
+        with dag_maker("test_dag"):
+            pass
+        dag_maker.create_dagrun()
+        session.commit()
+
+        response = client.get("/execution/dag-runs/count", params={"dag_id": 
"test_dag"})
+        assert response.status_code == 200
+        assert response.json() == 1
+
+    def test_get_count_with_states(self, client, session, dag_maker):
+        """Test counting DAG runs in specific states."""
+        with dag_maker("test_get_count_with_states"):
+            pass
+
+        # Create DAG runs with different states
+        dag_maker.create_dagrun(
+            state=State.SUCCESS, logical_date=timezone.datetime(2025, 1, 1), 
run_id="test_run_id1"
+        )
+        dag_maker.create_dagrun(
+            state=State.FAILED, logical_date=timezone.datetime(2025, 1, 2), 
run_id="test_run_id2"
+        )
+        dag_maker.create_dagrun(
+            state=State.RUNNING, logical_date=timezone.datetime(2025, 1, 3), 
run_id="test_run_id3"
+        )
+        session.commit()
+
+        response = client.get(
+            "/execution/dag-runs/count",
+            params={"dag_id": "test_get_count_with_states", "states": 
[State.SUCCESS, State.FAILED]},
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
+
+    def test_get_count_with_logical_dates(self, client, session, dag_maker):
+        with dag_maker("test_get_count_with_logical_dates"):
+            pass
+
+        date1 = timezone.datetime(2025, 1, 1)
+        date2 = timezone.datetime(2025, 1, 2)
+
+        dag_maker.create_dagrun(run_id="test_run_id1", logical_date=date1)
+        dag_maker.create_dagrun(run_id="test_run_id2", logical_date=date2)
+        session.commit()
+
+        response = client.get(
+            "/execution/dag-runs/count",
+            params={
+                "dag_id": "test_get_count_with_logical_dates",
+                "logical_dates": [date1.isoformat(), date2.isoformat()],
+            },
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
+
+    def test_get_count_with_run_ids(self, client, session, dag_maker):
+        with dag_maker("test_get_count_with_run_ids"):
+            pass
+
+        dag_maker.create_dagrun(run_id="run1", 
logical_date=timezone.datetime(2025, 1, 1))
+        dag_maker.create_dagrun(run_id="run2", 
logical_date=timezone.datetime(2025, 1, 2))
+        session.commit()
+
+        response = client.get(
+            "/execution/dag-runs/count",
+            params={"dag_id": "test_get_count_with_run_ids", "run_ids": 
["run1", "run2"]},
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
+
+    def test_get_count_with_mixed_states(self, client, session, dag_maker):
+        with dag_maker("test_get_count_with_mixed"):
+            pass
+        dag_maker.create_dagrun(
+            state=State.SUCCESS, run_id="runid1", 
logical_date=timezone.datetime(2025, 1, 1)
+        )
+        dag_maker.create_dagrun(
+            state=State.QUEUED, run_id="runid2", 
logical_date=timezone.datetime(2025, 1, 2)
+        )
+        session.commit()
+
+        response = client.get(
+            "/execution/dag-runs/count",
+            params={"dag_id": "test_get_count_with_mixed", "states": 
[State.SUCCESS, State.QUEUED]},
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 81c99b271be..147209e967a 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -33,6 +33,7 @@ from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, Asset
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.providers.standard.operators.empty import EmptyOperator
+from airflow.sdk import TaskGroup
 from airflow.utils import timezone
 from airflow.utils.state import State, TaskInstanceState, TerminalTIState
 
@@ -1223,3 +1224,167 @@ class TestGetRescheduleStartDate:
         response = 
client.get(f"/execution/task-reschedules/{ti.id}/start_date?try_number=2")
         assert response.status_code == 200
         assert response.json() == "2024-01-02T00:00:00Z"
+
+
+class TestGetCount:
+    def setup_method(self):
+        clear_db_runs()
+
+    def teardown_method(self):
+        clear_db_runs()
+
+    def test_get_count_basic(self, client, session, create_task_instance):
+        create_task_instance(task_id="test_task", state=State.SUCCESS)
+        session.commit()
+
+        response = client.get("/execution/task-instances/count", 
params={"dag_id": "dag"})
+        assert response.status_code == 200
+        assert response.json() == 1
+
+    def test_get_count_with_task_ids(self, client, session, 
create_task_instance):
+        for i in range(3):
+            create_task_instance(
+                task_id=f"task{i}",
+                state=State.SUCCESS,
+                dag_id="test_get_count_with_task_ids",
+                run_id=f"test_run_id{i}",
+            )
+        session.commit()
+
+        response = client.get(
+            "/execution/task-instances/count",
+            params={"dag_id": "test_get_count_with_task_ids", "task_ids": 
["task1", "task2"]},
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
+
+    def test_get_count_with_states(self, client, session, dag_maker):
+        """Test counting tasks in specific states."""
+        with dag_maker("test_get_count_with_states"):
+            EmptyOperator(task_id="task1")
+            EmptyOperator(task_id="task2")
+            EmptyOperator(task_id="task3")
+
+        dr = dag_maker.create_dagrun()
+
+        tis = dr.get_task_instances()
+
+        # Set different states for the task instances
+        for ti, state in zip(tis, [State.SUCCESS, State.FAILED, 
State.SKIPPED]):
+            ti.state = state
+            session.merge(ti)
+        session.commit()
+
+        response = client.get(
+            "/execution/task-instances/count",
+            params={"dag_id": "test_get_count_with_states", "states": 
[State.SUCCESS, State.FAILED]},
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
+
+    def test_get_count_with_logical_dates(self, client, session, dag_maker):
+        with dag_maker("test_get_count_with_logical_dates"):
+            EmptyOperator(task_id="task1")
+
+        date1 = timezone.datetime(2025, 1, 1)
+        date2 = timezone.datetime(2025, 1, 2)
+
+        dag_maker.create_dagrun(run_id="test_run_id1", logical_date=date1)
+        dag_maker.create_dagrun(run_id="test_run_id2", logical_date=date2)
+
+        session.commit()
+
+        response = client.get(
+            "/execution/task-instances/count",
+            params={
+                "dag_id": "test_get_count_with_logical_dates",
+                "logical_dates": [date1.isoformat(), date2.isoformat()],
+            },
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
+
+    def test_get_count_with_run_ids(self, client, session, dag_maker):
+        with dag_maker("test_get_count_with_run_ids"):
+            EmptyOperator(task_id="task1")
+
+        dag_maker.create_dagrun(run_id="run1", 
logical_date=timezone.datetime(2025, 1, 1))
+        dag_maker.create_dagrun(run_id="run2", 
logical_date=timezone.datetime(2025, 1, 2))
+
+        session.commit()
+
+        response = client.get(
+            "/execution/task-instances/count",
+            params={"dag_id": "test_get_count_with_run_ids", "run_ids": 
["run1", "run2"]},
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
+
+    def test_get_count_with_task_group(self, client, session, dag_maker):
+        with dag_maker(dag_id="test_dag", serialized=True):
+            with TaskGroup("group1"):
+                EmptyOperator(task_id="task1")
+                EmptyOperator(task_id="task2")
+
+            with TaskGroup("group2"):
+                EmptyOperator(task_id="task3")
+
+        dag_maker.create_dagrun(session=session)
+        session.commit()
+
+        response = client.get(
+            "/execution/task-instances/count",
+            params={"dag_id": "test_dag", "task_group_id": "group1"},
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
+
+    def test_get_count_task_group_not_found(self, client, session, dag_maker):
+        with dag_maker(dag_id="test_get_count_task_group_not_found", 
serialized=True):
+            with TaskGroup("group1"):
+                EmptyOperator(task_id="task1")
+        dag_maker.create_dagrun(session=session)
+
+        response = client.get(
+            "/execution/task-instances/count",
+            params={"dag_id": "test_get_count_task_group_not_found", 
"task_group_id": "non_existent_group"},
+        )
+        assert response.status_code == 404
+        assert response.json()["detail"] == {
+            "reason": "not_found",
+            "message": "Task group non_existent_group not found in DAG 
test_get_count_task_group_not_found",
+        }
+
+    def test_get_count_dag_not_found(self, client, session):
+        response = client.get(
+            "/execution/task-instances/count",
+            params={"dag_id": "non_existent_dag", "task_group_id": "group1"},
+        )
+        assert response.status_code == 404
+        assert response.json()["detail"] == {
+            "reason": "not_found",
+            "message": "DAG non_existent_dag not found",
+        }
+
+    def test_get_count_with_none_state(self, client, session, 
create_task_instance):
+        create_task_instance(task_id="task1", dag_id="get_count_with_none", 
state=None)
+        session.commit()
+
+        response = client.get(
+            "/execution/task-instances/count",
+            params={"dag_id": "get_count_with_none", "states": ["null"]},
+        )
+        assert response.status_code == 200
+        assert response.json() == 1
+
+    def test_get_count_with_mixed_states(self, client, session, 
create_task_instance):
+        create_task_instance(task_id="task1", state=State.SUCCESS, 
run_id="runid1", dag_id="mixed_states")
+        create_task_instance(task_id="task2", state=None, run_id="runid2", 
dag_id="mixed_states")
+        session.commit()
+
+        response = client.get(
+            "/execution/task-instances/count",
+            params={"dag_id": "mixed_states", "states": [State.SUCCESS, 
"null"]},
+        )
+        assert response.status_code == 200
+        assert response.json() == 2
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index 3cc66ac9f96..b92c27141d7 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -2318,3 +2318,14 @@ def run_task(create_runtime_ti, mock_supervisor_comms, spy_agency) -> RunTaskCal
 def mock_xcom_backend():
     with mock.patch("airflow.sdk.execution_time.task_runner.XCom", 
create=True) as xcom_backend:
         yield xcom_backend
+
+
[email protected]
+def testing_dag_bundle():
+    from airflow.models.dagbundle import DagBundleModel
+    from airflow.utils.session import create_session
+
+    with create_session() as session:
+        if session.query(DagBundleModel).filter(DagBundleModel.name == 
"testing").count() == 0:
+            testing = DagBundleModel(name="testing")
+            session.add(testing)
diff --git a/providers/standard/src/airflow/providers/standard/sensors/external_task.py b/providers/standard/src/airflow/providers/standard/sensors/external_task.py
index 0a34cc5d48f..dd5ad6c4e7a 100644
--- a/providers/standard/src/airflow/providers/standard/sensors/external_task.py
+++ b/providers/standard/src/airflow/providers/standard/sensors/external_task.py
@@ -25,27 +25,32 @@ from typing import TYPE_CHECKING, Any, Callable, ClassVar
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
-from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import DagModel
 from airflow.models.dagbag import DagBag
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.providers.standard.triggers.external_task import WorkflowTrigger
 from airflow.providers.standard.utils.sensor_helper import _get_count, 
_get_external_task_group_task_ids
 from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
-from airflow.sensors.base import BaseSensorOperator
 from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import State, TaskInstanceState
 
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk.bases.sensor import BaseSensorOperator
+else:
+    from airflow.sensors.base import BaseSensorOperator
+
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
     from airflow.models.taskinstancekey import TaskInstanceKey
 
     try:
+        from airflow.sdk import BaseOperator
         from airflow.sdk.definitions.context import Context
     except ImportError:
         # TODO: Remove once provider drops support for Airflow 2
+        from airflow.models.baseoperator import BaseOperator
         from airflow.utils.context import Context
 
 
@@ -65,15 +70,16 @@ class ExternalDagLink(BaseOperatorLink):
     name = "External DAG"
 
     def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> 
str:
-        from airflow.models.renderedtifields import RenderedTaskInstanceFields
-
         if TYPE_CHECKING:
             assert isinstance(operator, (ExternalTaskMarker, 
ExternalTaskSensor))
 
-        if template_fields := 
RenderedTaskInstanceFields.get_templated_fields(ti_key):
-            external_dag_id: str = template_fields.get("external_dag_id", 
operator.external_dag_id)
-        else:
-            external_dag_id = operator.external_dag_id
+        external_dag_id = operator.external_dag_id
+
+        if not AIRFLOW_V_3_0_PLUS:
+            from airflow.models.renderedtifields import 
RenderedTaskInstanceFields
+
+            if template_fields := 
RenderedTaskInstanceFields.get_templated_fields(ti_key):
+                external_dag_id: str = template_fields.get("external_dag_id", 
operator.external_dag_id)  # type: ignore[no-redef]
 
         if AIRFLOW_V_3_0_PLUS:
             from airflow.utils.helpers import build_airflow_dagrun_url
@@ -86,9 +92,7 @@ class ExternalDagLink(BaseOperatorLink):
             return build_airflow_url_with_query(query)
 
 
-# TODO: Remove BaseOperator from inheritance in 
https://github.com/apache/airflow/issues/47447
-#   It is only temporary until we refactor the code to not directly go to the 
DB.
-class ExternalTaskSensor(BaseSensorOperator, BaseOperator):
+class ExternalTaskSensor(BaseSensorOperator):
     """
     Waits for a different DAG, task group, or task to complete for a specific 
logical date.
 
@@ -247,16 +251,22 @@ class ExternalTaskSensor(BaseSensorOperator, BaseOperator):
         self.poll_interval = poll_interval
 
     def _get_dttm_filter(self, context):
+        logical_date = context.get("logical_date")
+        if logical_date is None:
+            dag_run = context.get("dag_run")
+            if TYPE_CHECKING:
+                assert dag_run
+
+            logical_date = dag_run.run_after
         if self.execution_delta:
-            dttm = context["logical_date"] - self.execution_delta
+            dttm = logical_date - self.execution_delta
         elif self.execution_date_fn:
             dttm = self._handle_execution_date_fn(context=context)
         else:
-            dttm = context["logical_date"]
+            dttm = logical_date
         return dttm if isinstance(dttm, list) else [dttm]
 
-    @provide_session
-    def poke(self, context: Context, session: Session = NEW_SESSION) -> bool:
+    def poke(self, context: Context) -> bool:
         # delay check to poke rather than __init__ in case it was supplied as 
XComArgs
         if self.external_task_ids and len(self.external_task_ids) > 
len(set(self.external_task_ids)):
             raise ValueError("Duplicate task_ids passed in external_task_ids 
parameter")
@@ -287,15 +297,62 @@ class ExternalTaskSensor(BaseSensorOperator, BaseOperator):
                 serialized_dttm_filter,
             )
 
-        # In poke mode this will check dag existence only once
-        if self.check_existence and not self._has_checked_existence:
-            self._check_for_existence(session=session)
+        if AIRFLOW_V_3_0_PLUS:
+            return self._poke_af3(context, dttm_filter)
+        else:
+            return self._poke_af2(dttm_filter)
+
+    def _poke_af3(self, context: Context, dttm_filter: 
list[datetime.datetime]) -> bool:
+        self._has_checked_existence = True
+        ti = context["ti"]
+
+        def _get_count(states: list[str]) -> int:
+            if self.external_task_ids:
+                return ti.get_ti_count(
+                    dag_id=self.external_dag_id,
+                    task_ids=self.external_task_ids,  # type: ignore[arg-type]
+                    logical_dates=dttm_filter,
+                    states=states,
+                )
+            elif self.external_task_group_id:
+                return ti.get_ti_count(
+                    dag_id=self.external_dag_id,
+                    task_group_id=self.external_task_group_id,
+                    logical_dates=dttm_filter,
+                    states=states,
+                )
+            else:
+                return ti.get_dr_count(
+                    dag_id=self.external_dag_id,
+                    logical_dates=dttm_filter,
+                    states=states,
+                )
 
-        count_failed = -1
         if self.failed_states:
-            count_failed = self.get_count(dttm_filter, session, 
self.failed_states)
+            count = _get_count(self.failed_states)
+            count_failed = self._calculate_count(count, dttm_filter)
+            self._handle_failed_states(count_failed)
 
-        # Fail if anything in the list has failed.
+        if self.skipped_states:
+            count = _get_count(self.skipped_states)
+            count_skipped = self._calculate_count(count, dttm_filter)
+            self._handle_skipped_states(count_skipped)
+
+        count = _get_count(self.allowed_states)
+        count_allowed = self._calculate_count(count, dttm_filter)
+        return count_allowed == len(dttm_filter)
+
+    def _calculate_count(self, count: int, dttm_filter: 
list[datetime.datetime]) -> float | int:
+        """Calculate the normalized count based on the type of check."""
+        if self.external_task_ids:
+            return count / len(self.external_task_ids)
+        elif self.external_task_group_id:
+            return count / len(dttm_filter)
+        else:
+            return count
+
+    def _handle_failed_states(self, count_failed: float | int) -> None:
+        """Handle failed states and raise appropriate exceptions."""
         if count_failed > 0:
             if self.external_task_ids:
                 if self.soft_fail:
@@ -317,7 +374,6 @@ class ExternalTaskSensor(BaseSensorOperator, BaseOperator):
                     f"The external task_group '{self.external_task_group_id}' "
                     f"in DAG '{self.external_dag_id}' failed."
                 )
-
             else:
                 if self.soft_fail:
                     raise AirflowSkipException(
@@ -325,12 +381,8 @@ class ExternalTaskSensor(BaseSensorOperator, BaseOperator):
                     )
                 raise AirflowException(f"The external DAG 
{self.external_dag_id} failed.")
 
-        count_skipped = -1
-        if self.skipped_states:
-            count_skipped = self.get_count(dttm_filter, session, 
self.skipped_states)
-
-        # Skip if anything in the list has skipped. Note if we are checking 
multiple tasks and one skips
-        # before another errors, we'll skip first.
+    def _handle_skipped_states(self, count_skipped: float | int) -> None:
+        """Handle skipped states and raise appropriate exceptions."""
         if count_skipped > 0:
             if self.external_task_ids:
                 raise AirflowSkipException(
@@ -348,7 +400,19 @@ class ExternalTaskSensor(BaseSensorOperator, BaseOperator):
                     "Skipping."
                 )
 
-        # only go green if every single task has reached an allowed state
+    @provide_session
+    def _poke_af2(self, dttm_filter: list[datetime.datetime], session: Session 
= NEW_SESSION) -> bool:
+        if self.check_existence and not self._has_checked_existence:
+            self._check_for_existence(session=session)
+
+        if self.failed_states:
+            count_failed = self.get_count(dttm_filter, session, 
self.failed_states)
+            self._handle_failed_states(count_failed)
+
+        if self.skipped_states:
+            count_skipped = self.get_count(dttm_filter, session, 
self.skipped_states)
+            self._handle_skipped_states(count_skipped)
+
         count_allowed = self.get_count(dttm_filter, session, 
self.allowed_states)
         return count_allowed == len(dttm_filter)
 
@@ -483,6 +547,9 @@ class ExternalTaskMarker(EmptyOperator):
     """
 
     template_fields = ["external_dag_id", "external_task_id", "logical_date"]
+    if not AIRFLOW_V_3_0_PLUS:
+        template_fields.append("execution_date")
+
     ui_color = "#4db7db"
     operator_extra_links = [ExternalDagLink()]
 
@@ -510,6 +577,9 @@ class ExternalTaskMarker(EmptyOperator):
                 f"Expected str or datetime.datetime type for logical_date. Got 
{type(logical_date)}"
             )
 
+        if not AIRFLOW_V_3_0_PLUS:
+            self.execution_date = self.logical_date
+
         if recursion_depth <= 0:
             raise ValueError("recursion_depth should be a positive integer")
         self.recursion_depth = recursion_depth
diff --git 
a/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py 
b/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py
index ae5d3c12985..17d54e371bc 100644
--- a/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py
+++ b/providers/standard/src/airflow/providers/standard/utils/sensor_helper.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, cast
 from sqlalchemy import func, select, tuple_
 
 from airflow.models import DagBag, DagRun, TaskInstance
+from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.utils.session import NEW_SESSION, provide_session
 
 if TYPE_CHECKING:
@@ -88,8 +89,10 @@ def _count_stmt(model, states, dttm_filter, external_dag_id) 
-> Executable:
     :param dttm_filter: date time filter for logical date
     :param external_dag_id: The ID of the external DAG.
     """
+    date_field = model.logical_date if AIRFLOW_V_3_0_PLUS else 
model.execution_date
+
     return select(func.count()).where(
-        model.dag_id == external_dag_id, model.state.in_(states), 
model.logical_date.in_(dttm_filter)
+        model.dag_id == external_dag_id, model.state.in_(states), 
date_field.in_(dttm_filter)
     )
 
 
@@ -106,11 +109,13 @@ def _get_external_task_group_task_ids(dttm_filter, 
external_task_group_id, exter
     task_group = refreshed_dag_info.task_group_dict.get(external_task_group_id)
 
     if task_group:
+        date_field = TaskInstance.logical_date if AIRFLOW_V_3_0_PLUS else 
TaskInstance.execution_date
+
         group_tasks = session.scalars(
             select(TaskInstance).filter(
                 TaskInstance.dag_id == external_dag_id,
                 TaskInstance.task_id.in_(task.task_id for task in task_group),
-                TaskInstance.logical_date.in_(dttm_filter),
+                date_field.in_(dttm_filter),
             )
         )
 
diff --git a/airflow-core/tests/unit/sensors/test_external_task_sensor.py 
b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
similarity index 79%
rename from airflow-core/tests/unit/sensors/test_external_task_sensor.py
rename to 
providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
index 1a7938cc27d..95deddd4ade 100644
--- a/airflow-core/tests/unit/sensors/test_external_task_sensor.py
+++ 
b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
@@ -19,10 +19,7 @@ from __future__ import annotations
 
 import itertools
 import logging
-import os
 import re
-import tempfile
-import zipfile
 from datetime import time, timedelta
 from unittest import mock
 
@@ -47,8 +44,7 @@ from airflow.providers.standard.sensors.time import TimeSensor
 from airflow.providers.standard.triggers.external_task import WorkflowTrigger
 from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.timetables.base import DataInterval
-from airflow.utils.hashlib_wrapper import md5
-from airflow.utils.session import NEW_SESSION, create_session, provide_session
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.timezone import coerce_datetime, datetime
@@ -57,7 +53,6 @@ from airflow.utils.types import DagRunType
 from tests_common.test_utils.db import clear_db_runs
 from tests_common.test_utils.mock_operators import MockOperator
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-from unit.models import TEST_DAGS_FOLDER
 
 if AIRFLOW_V_3_0_PLUS:
     from airflow.utils.types import DagRunTriggeredByType
@@ -81,42 +76,13 @@ def clean_db():
     clear_db_runs()
 
 
[email protected]
-def dag_zip_maker(testing_dag_bundle):
-    class DagZipMaker:
-        def __call__(self, *dag_files):
-            self.__dag_files = [os.sep.join([TEST_DAGS_FOLDER.__str__(), 
dag_file]) for dag_file in dag_files]
-            dag_files_hash = 
md5("".join(self.__dag_files).encode()).hexdigest()
-            self.__tmp_dir = os.sep.join([tempfile.tempdir, dag_files_hash])
-
-            self.__zip_file_name = os.sep.join([self.__tmp_dir, 
f"{dag_files_hash}.zip"])
-
-            if not os.path.exists(self.__tmp_dir):
-                os.mkdir(self.__tmp_dir)
-            return self
-
-        def __enter__(self):
-            with zipfile.ZipFile(self.__zip_file_name, "x") as zf:
-                for dag_file in self.__dag_files:
-                    zf.write(dag_file, os.path.basename(dag_file))
-            dagbag = DagBag(dag_folder=self.__tmp_dir, include_examples=False)
-            dagbag.sync_to_db("testing", None)
-            return dagbag
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            os.unlink(self.__zip_file_name)
-            os.rmdir(self.__tmp_dir)
-
-    return DagZipMaker()
-
-
[email protected]("testing_dag_bundle")
-class TestExternalTaskSensor:
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for v3.0+")
+class TestExternalTaskSensorV2:
     def setup_method(self):
         self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
         self.args = {"owner": "airflow", "start_date": DEFAULT_DATE}
         self.dag = DAG(TEST_DAG_ID, schedule=None, default_args=self.args)
-        self.dag_run_id = 
DagRunType.MANUAL.generate_run_id(suffix=DEFAULT_DATE.isoformat())
+        self.dag_run_id = DagRunType.MANUAL.generate_run_id(DEFAULT_DATE)
 
     def add_time_sensor(self, task_id=TEST_TASK_ID):
         # TODO: Remove BaseOperator in 
https://github.com/apache/airflow/issues/47447
@@ -133,7 +99,7 @@ class TestExternalTaskSensor:
             with TaskGroup(group_id=TEST_TASK_GROUP_ID) as task_group:
                 _ = [EmptyOperator(task_id=f"task{i}") for i in 
range(len(target_states))]
             dag.sync_to_db()
-            SerializedDagModel.write_dag(dag, bundle_name="test_bundle")
+            SerializedDagModel.write_dag(dag)
 
         for idx, task in enumerate(task_group):
             ti = TaskInstance(task=task, run_id=self.dag_run_id)
@@ -156,7 +122,7 @@ class TestExternalTaskSensor:
                 fake_task()
                 fake_mapped_task.expand(x=list(map_indexes))
         dag.sync_to_db()
-        SerializedDagModel.write_dag(dag, bundle_name="test_bundle")
+        SerializedDagModel.write_dag(dag)
 
         for task in task_group:
             if task.task_id == "fake_mapped_task":
@@ -530,7 +496,7 @@ exit 0
                 .filter(
                     TI.dag_id == dag_external_id,
                     TI.state == State.FAILED,
-                    TI.logical_date == DEFAULT_DATE + timedelta(seconds=1),
+                    TI.execution_date == DEFAULT_DATE + timedelta(seconds=1),
                 )
                 .all()
             )
@@ -977,10 +943,301 @@ exit 0
             check_existence=True,
             **kwargs,
         )
+        if not hasattr(op, "never_fail"):
+            expected_message = "Skipping due to soft_fail is set to True." if 
soft_fail else expected_message
         with pytest.raises(expected_exception, match=expected_message):
             op.execute(context={})
 
 
[email protected](not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
[email protected]("testing_dag_bundle")
+class TestExternalTaskSensorV3:
+    def setup_method(self):
+        # Create a mock for TaskInstance with get_ti_count method
+        mock_ti = mock.MagicMock()
+        mock_ti.get_ti_count = mock.MagicMock(return_value=0)  # Default 
return value
+
+        self.context = {
+            "execution_date": DEFAULT_DATE,
+            "logical_date": DEFAULT_DATE,
+            "ti": mock_ti,
+            "task": mock.MagicMock(),
+            "run_id": "test_run_id",
+        }
+
+    @pytest.mark.execution_timeout(10)
+    def test_external_task_sensor_success(self, dag_maker):
+        """Test that the sensor succeeds when the external task succeeds."""
+        with dag_maker("test_dag_child"):
+            op = ExternalTaskSensor(
+                task_id="test_external_task_sensor_check",
+                external_dag_id="test_dag_parent",
+                external_task_id="test_external_task_sensor_success",
+                allowed_states=["success"],
+            )
+
+        # Mimic DB response to get_ti_count as 1
+        self.context["ti"].get_ti_count.return_value = 1
+
+        op.execute(context=self.context)
+
+        self.context["ti"].get_ti_count.assert_called_once_with(
+            dag_id="test_dag_parent",
+            logical_dates=[DEFAULT_DATE],
+            states=["success"],
+            task_ids=["test_external_task_sensor_success"],
+        )
+
+    @pytest.mark.execution_timeout(10)
+    def test_external_task_sensor_failure(self, dag_maker):
+        """Test that the sensor fails when the external task fails."""
+        with dag_maker("test_dag_child"):
+            op = ExternalTaskSensor(
+                task_id="test_external_task_sensor_check",
+                external_dag_id="test_dag_parent",
+                external_task_id="test_external_task_sensor_failure",
+                failed_states=[State.FAILED],
+            )
+
+        self.context["ti"].get_ti_count.return_value = 1
+
+        with pytest.raises(AirflowException):
+            op.execute(context=self.context)
+
+        self.context["ti"].get_ti_count.assert_called_once_with(
+            dag_id="test_dag_parent",
+            logical_dates=[DEFAULT_DATE],
+            states=[State.FAILED],
+            task_ids=["test_external_task_sensor_failure"],
+        )
+
+    @pytest.mark.execution_timeout(10)
+    def test_external_task_sensor_soft_fail(self, dag_maker):
+        """Test that the sensor skips when soft_fail is True and external task 
fails."""
+        with dag_maker("test_dag_child"):
+            op = ExternalTaskSensor(
+                task_id="test_external_task_sensor_check",
+                external_dag_id="test_dag_parent",
+                external_task_id="test_external_task_sensor_soft_fail",
+                failed_states=[State.FAILED],
+                soft_fail=True,
+            )
+
+        self.context["ti"].get_ti_count.return_value = 1
+
+        with pytest.raises(AirflowSkipException):
+            op.execute(context=self.context)
+
+        self.context["ti"].get_ti_count.assert_called_once_with(
+            dag_id="test_dag_parent",
+            logical_dates=[DEFAULT_DATE],
+            states=[State.FAILED],
+            task_ids=["test_external_task_sensor_soft_fail"],
+        )
+
+    @pytest.mark.execution_timeout(10)
+    def test_external_task_sensor_multiple_task_ids(self, dag_maker):
+        with dag_maker("test_dag_child"):
+            op = ExternalTaskSensor(
+                task_id="test_external_task_sensor_check",
+                external_dag_id="test_dag_parent",
+                external_task_ids=["task1", "task2"],
+                allowed_states=["success"],
+            )
+
+        self.context["ti"].get_ti_count.return_value = 2
+        op.execute(context=self.context)
+
+        self.context["ti"].get_ti_count.assert_called_once_with(
+            dag_id="test_dag_parent",
+            logical_dates=[DEFAULT_DATE],
+            states=["success"],
+            task_ids=["task1", "task2"],
+        )
+
+    @pytest.mark.execution_timeout(10)
+    def test_external_task_sensor_skipped_states(self, dag_maker):
+        with dag_maker("test_dag_child"):
+            op = ExternalTaskSensor(
+                task_id="test_external_task_sensor_check",
+                external_dag_id="test_dag_parent",
+                external_task_id="test_external_task_sensor_skipped_states",
+                skipped_states=[State.SKIPPED],
+            )
+
+        self.context["ti"].get_ti_count.return_value = 1
+        with pytest.raises(AirflowSkipException):
+            op.execute(context=self.context)
+
+        self.context["ti"].get_ti_count.assert_called_once_with(
+            dag_id="test_dag_parent",
+            logical_dates=[DEFAULT_DATE],
+            states=[State.SKIPPED],
+            task_ids=["test_external_task_sensor_skipped_states"],
+        )
+
+    def test_external_task_sensor_invalid_combination(self, dag_maker):
+        """Test that the sensor raises an error with invalid parameter 
combinations."""
+        with pytest.raises(ValueError):
+            with dag_maker("test_external_task_sensor_invalid_combination"):
+                ExternalTaskSensor(
+                    task_id="test_external_task_sensor_check",
+                    external_dag_id="test_dag",
+                    external_task_id="test_task",
+                    external_task_ids=["test_task"],
+                )
+
+    def test_external_task_sensor_invalid_state(self, dag_maker):
+        with pytest.raises(ValueError):
+            with dag_maker("test_external_task_sensor_invalid_state"):
+                ExternalTaskSensor(
+                    task_id="test_external_task_sensor_check",
+                    external_dag_id="test_dag",
+                    external_task_id="test_task",
+                    allowed_states=["invalid_state"],
+                )
+
+    @pytest.mark.execution_timeout(10)
+    def test_external_task_sensor_task_group(self, dag_maker):
+        with dag_maker("test_dag_child"):
+            op = ExternalTaskSensor(
+                task_id="test_external_task_sensor_check",
+                external_dag_id="test_dag_parent",
+                external_task_group_id="test_group",
+                allowed_states=["success"],
+            )
+
+        self.context["ti"].get_ti_count.return_value = 1
+        op.execute(context=self.context)
+
+        self.context["ti"].get_ti_count.assert_called_once_with(
+            dag_id="test_dag_parent",
+            logical_dates=[DEFAULT_DATE],
+            states=["success"],
+            task_group_id="test_group",
+        )
+
+    @pytest.mark.execution_timeout(10)
+    def test_external_task_sensor_execution_date_fn(self, dag_maker):
+        def execution_date_fn(dt):
+            return [dt + timedelta(hours=1)]
+
+        with dag_maker("test_dag_child"):
+            op = ExternalTaskSensor(
+                task_id="test_external_task_sensor_check",
+                external_dag_id="test_dag_parent",
+                external_task_id="test_task",
+                execution_date_fn=execution_date_fn,
+                allowed_states=["success"],
+            )
+
+        self.context["ti"].get_ti_count.return_value = 1
+        op.execute(context=self.context)
+
+        expected_date = DEFAULT_DATE + timedelta(hours=1)
+        self.context["ti"].get_ti_count.assert_called_once_with(
+            dag_id="test_dag_parent",
+            logical_dates=[expected_date],
+            states=["success"],
+            task_ids=["test_task"],
+        )
+
+    @pytest.mark.execution_timeout(10)
+    def test_external_task_sensor_execution_delta(self, dag_maker):
+        with dag_maker("test_dag_child"):
+            op = ExternalTaskSensor(
+                task_id="test_external_task_sensor_check",
+                external_dag_id="test_dag_parent",
+                external_task_id="test_task",
+                execution_delta=timedelta(hours=1),
+                allowed_states=["success"],
+            )
+
+        self.context["ti"].get_ti_count.return_value = 1
+        op.execute(context=self.context)
+
+        expected_date = DEFAULT_DATE - timedelta(hours=1)
+        self.context["ti"].get_ti_count.assert_called_once_with(
+            dag_id="test_dag_parent",
+            logical_dates=[expected_date],
+            states=["success"],
+            task_ids=["test_task"],
+        )
+
+    @pytest.mark.execution_timeout(10)
+    def test_external_task_sensor_duplicate_task_ids(self, dag_maker):
+        with dag_maker("test_dag_child"):
+            op = ExternalTaskSensor(
+                task_id="test_external_task_sensor_check",
+                external_dag_id="test_dag_parent",
+                external_task_ids=["task1", "task1"],
+                allowed_states=["success"],
+            )
+
+        with pytest.raises(ValueError, match="Duplicate task_ids passed in 
external_task_ids parameter"):
+            op.execute(context=self.context)
+
+    @pytest.mark.execution_timeout(10)
+    def test_external_task_sensor_deferrable(self, dag_maker):
+        with dag_maker("test_dag_child"):
+            op = ExternalTaskSensor(
+                task_id="test_external_task_sensor_check",
+                external_dag_id="test_dag_parent",
+                external_task_id="test_task",
+                deferrable=True,
+                allowed_states=["success"],
+            )
+
+        with pytest.raises(TaskDeferred) as exc:
+            op.execute(context=self.context)
+
+        assert isinstance(exc.value.trigger, WorkflowTrigger)
+        assert exc.value.trigger.external_dag_id == "test_dag_parent"
+        assert exc.value.trigger.external_task_ids == ["test_task"]
+
+    @pytest.mark.execution_timeout(10)
+    def test_external_task_sensor_only_dag_id(self, dag_maker):
+        """Test that the sensor works correctly when only external_dag_id is 
provided."""
+        with dag_maker("test_dag_child"):
+            op = ExternalTaskSensor(
+                task_id="test_external_task_sensor_check",
+                external_dag_id="test_dag_parent",
+                allowed_states=["success"],
+            )
+
+        self.context["ti"].get_dr_count = mock.MagicMock(return_value=1)
+
+        op.execute(context=self.context)
+
+        self.context["ti"].get_dr_count.assert_called_once_with(
+            dag_id="test_dag_parent",
+            logical_dates=[DEFAULT_DATE],
+            states=["success"],
+        )
+
+    @pytest.mark.execution_timeout(10)
+    def test_external_task_sensor_task_group_failed_states(self, dag_maker):
+        with dag_maker("test_dag_child"):
+            op = ExternalTaskSensor(
+                task_id="test_external_task_sensor_check",
+                external_dag_id="test_dag_parent",
+                external_task_group_id="test_group",
+                failed_states=[State.FAILED],
+            )
+
+        self.context["ti"].get_ti_count.return_value = 1
+
+        with pytest.raises(AirflowException):
+            op.execute(context=self.context)
+
+        self.context["ti"].get_ti_count.assert_called_once_with(
+            dag_id="test_dag_parent",
+            logical_dates=[DEFAULT_DATE],
+            states=[State.FAILED],
+            task_group_id="test_group",
+        )
+
+
 class TestExternalTaskAsyncSensor:
     TASK_ID = "external_task_sensor_check"
     EXTERNAL_DAG_ID = "child_dag"  # DAG the external task sensor is waiting on
@@ -1050,14 +1307,7 @@ class TestExternalTaskAsyncSensor:
         mock_log_info.assert_called_with("External tasks %s has executed 
successfully.", [EXTERNAL_TASK_ID])
 
 
-def test_external_task_sensor_check_zipped_dag_existence(dag_zip_maker):
-    with dag_zip_maker("test_external_task_sensor_check_existense.py") as 
dagbag:
-        with create_session() as session:
-            dag = dagbag.dags["test_external_task_sensor_check_existence"]
-            op = dag.tasks[0]
-            op._check_for_existence(session)
-
-
[email protected](not AIRFLOW_V_3_0_PLUS, reason="Needs Flask app context 
fixture for AF 2")
 @pytest.mark.parametrize(
     argnames=["external_dag_id", "external_task_id", 
"expected_external_dag_id", "expected_external_task_id"],
     argvalues=[
@@ -1166,7 +1416,10 @@ def dag_bag_ext():
     task_a_3 >> task_b_3
 
     for dag in [dag_0, dag_1, dag_2, dag_3]:
-        dag_bag.bag_dag(dag=dag)
+        if AIRFLOW_V_3_0_PLUS:
+            dag_bag.bag_dag(dag=dag)
+        else:
+            dag_bag.bag_dag(dag=dag, root_dag=dag)
 
     yield dag_bag
 
@@ -1215,7 +1468,10 @@ def dag_bag_parent_child():
         )
 
     for dag in [dag_0, dag_1]:
-        dag_bag.bag_dag(dag=dag)
+        if AIRFLOW_V_3_0_PLUS:
+            dag_bag.bag_dag(dag=dag)
+        else:
+            dag_bag.bag_dag(dag=dag, root_dag=dag)
 
     yield dag_bag
 
@@ -1237,22 +1493,37 @@ def run_tasks(
 
     for dag in dag_bag.dags.values():
         data_interval = DataInterval(coerce_datetime(logical_date), 
coerce_datetime(logical_date))
-        runs[dag.dag_id] = dagrun = dag.create_dagrun(
-            run_id=dag.timetable.generate_run_id(
-                run_type=DagRunType.MANUAL,
+        if AIRFLOW_V_3_0_PLUS:
+            runs[dag.dag_id] = dagrun = dag.create_dagrun(
+                run_id=dag.timetable.generate_run_id(
+                    run_type=DagRunType.MANUAL,
+                    run_after=logical_date,
+                    data_interval=data_interval,
+                ),
+                logical_date=logical_date,
+                data_interval=data_interval,
                 run_after=logical_date,
+                run_type=DagRunType.MANUAL,
+                triggered_by=DagRunTriggeredByType.TEST,
+                dag_version=None,
+                state=DagRunState.RUNNING,
+                start_date=logical_date,
+                session=session,
+            )
+        else:
+            runs[dag.dag_id] = dagrun = dag.create_dagrun(  # type: 
ignore[call-arg]
+                run_id=dag.timetable.generate_run_id(  # type: ignore[call-arg]
+                    run_type=DagRunType.MANUAL,
+                    logical_date=logical_date,
+                    data_interval=data_interval,
+                ),
+                execution_date=logical_date,
                 data_interval=data_interval,
-            ),
-            logical_date=logical_date,
-            data_interval=data_interval,
-            run_after=logical_date,
-            run_type=DagRunType.MANUAL,
-            triggered_by=DagRunTriggeredByType.TEST,
-            dag_version=None,
-            state=DagRunState.RUNNING,
-            start_date=logical_date,
-            session=session,
-        )
+                run_type=DagRunType.MANUAL,
+                state=DagRunState.RUNNING,
+                start_date=logical_date,
+                session=session,
+            )
         # we use sorting by task_id here because for the test DAG structure of 
ours
         # this is equivalent to topological sort. It would not work in general 
case
         # but it works for our case because we specifically constructed test 
DAGS
@@ -1290,7 +1561,7 @@ def clear_tasks(
     """
     Clear the task and its downstream tasks recursively for the dag in the 
given dagbag.
     """
-    partial: DAG = dag.partial_subset(task_ids=[task.task_id], 
include_downstream=True)
+    partial: DAG = dag.partial_subset(task_ids_or_regex=[task.task_id], 
include_downstream=True)
     return partial.clear(
         start_date=start_date,
         end_date=end_date,
@@ -1300,6 +1571,7 @@ def clear_tasks(
     )
 
 
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+")
 def test_external_task_marker_transitive(dag_bag_ext):
     """
     Test clearing tasks across DAGs.
@@ -1314,6 +1586,7 @@ def test_external_task_marker_transitive(dag_bag_ext):
     assert_ti_state_equal(ti_b_3, State.NONE)
 
 
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+")
 @provide_session
 def test_external_task_marker_clear_activate(dag_bag_parent_child, session):
     """
@@ -1326,22 +1599,9 @@ def 
test_external_task_marker_clear_activate(dag_bag_parent_child, session):
     run_tasks(dag_bag, logical_date=day_1)
     run_tasks(dag_bag, logical_date=day_2)
 
-    from sqlalchemy import select
-
-    run_ids = []
     # Assert that dagruns of all the affected dags are set to SUCCESS before 
tasks are cleared.
-    for dag, logical_date in itertools.product(dag_bag.dags.values(), [day_1, 
day_2]):
-        run_id = (
-            select(DagRun.run_id)
-            .where(DagRun.logical_date == logical_date)
-            .order_by(DagRun.id.desc())
-            .limit(1)
-        )
-        run_ids.append(run_id)
-        dagrun = dag.get_dagrun(
-            run_id=run_id,
-            session=session,
-        )
+    for dag, execution_date in itertools.product(dag_bag.dags.values(), 
[day_1, day_2]):
+        dagrun = dag.get_dagrun(execution_date=execution_date, session=session)
         dagrun.set_state(State.SUCCESS)
     session.flush()
 
@@ -1351,10 +1611,10 @@ def 
test_external_task_marker_clear_activate(dag_bag_parent_child, session):
 
     # Assert that dagruns of all the affected dags are set to QUEUED after 
tasks are cleared.
     # Unaffected dagruns should be left as SUCCESS.
-    dagrun_0_1 = dag_bag.get_dag("parent_dag_0").get_dagrun(run_id=run_ids[0], 
session=session)
-    dagrun_0_2 = dag_bag.get_dag("parent_dag_0").get_dagrun(run_id=run_ids[1], 
session=session)
-    dagrun_1_1 = dag_bag.get_dag("child_dag_1").get_dagrun(run_id=run_ids[2], 
session=session)
-    dagrun_1_2 = dag_bag.get_dag("child_dag_1").get_dagrun(run_id=run_ids[3], 
session=session)
+    dagrun_0_1 = 
dag_bag.get_dag("parent_dag_0").get_dagrun(execution_date=day_1, 
session=session)
+    dagrun_0_2 = 
dag_bag.get_dag("parent_dag_0").get_dagrun(execution_date=day_2, 
session=session)
+    dagrun_1_1 = 
dag_bag.get_dag("child_dag_1").get_dagrun(execution_date=day_1, session=session)
+    dagrun_1_2 = 
dag_bag.get_dag("child_dag_1").get_dagrun(execution_date=day_2, session=session)
 
     assert dagrun_0_1.state == State.QUEUED
     assert dagrun_0_2.state == State.QUEUED
@@ -1362,6 +1622,7 @@ def 
test_external_task_marker_clear_activate(dag_bag_parent_child, session):
     assert dagrun_1_2.state == State.SUCCESS
 
 
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+")
 def test_external_task_marker_future(dag_bag_ext):
     """
     Test clearing tasks with no end_date. This is the case when users clear 
tasks with
@@ -1386,6 +1647,7 @@ def test_external_task_marker_future(dag_bag_ext):
     assert_ti_state_equal(ti_b_3_date_1, State.NONE)
 
 
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+")
 def test_external_task_marker_exception(dag_bag_ext):
     """
     Clearing across multiple DAGs should raise AirflowException if more levels 
are being cleared
@@ -1463,13 +1725,17 @@ def dag_bag_cyclic():
             task_a >> task_b
 
         for dag in dags:
-            dag_bag.bag_dag(dag=dag)
+            if AIRFLOW_V_3_0_PLUS:
+                dag_bag.bag_dag(dag=dag)
+            else:
+                dag_bag.bag_dag(dag=dag, root_dag=dag)  # type: 
ignore[call-arg]
 
         return dag_bag
 
     return _factory
 
 
[email protected](AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+")
 def test_external_task_marker_cyclic_deep(dag_bag_cyclic):
     """
     Tests clearing across multiple DAGs that have cyclic dependencies. AirflowException should be
@@ -1483,6 +1749,7 @@ def test_external_task_marker_cyclic_deep(dag_bag_cyclic):
         clear_tasks(dag_bag, dag_0, task_a_0)
 
 
+@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for 3.0+")
 def test_external_task_marker_cyclic_shallow(dag_bag_cyclic):
     """
     Tests clearing across multiple DAGs that have cyclic dependencies shallower
@@ -1513,8 +1780,13 @@ def dag_bag_multiple():
     dag_bag = DagBag(dag_folder=DEV_NULL, include_examples=False)
     daily_dag = DAG("daily_dag", start_date=DEFAULT_DATE, schedule="@daily")
     agg_dag = DAG("agg_dag", start_date=DEFAULT_DATE, schedule="@daily")
-    dag_bag.bag_dag(dag=daily_dag)
-    dag_bag.bag_dag(dag=agg_dag)
+
+    if AIRFLOW_V_3_0_PLUS:
+        dag_bag.bag_dag(dag=daily_dag)
+        dag_bag.bag_dag(dag=agg_dag)
+    else:
+        dag_bag.bag_dag(dag=daily_dag, root_dag=daily_dag)
+        dag_bag.bag_dag(dag=agg_dag, root_dag=agg_dag)
 
     daily_task = EmptyOperator(task_id="daily_tas", dag=daily_dag)
 
@@ -1584,7 +1856,10 @@ def dag_bag_head_tail():
         )
         head >> body >> tail
 
-    dag_bag.bag_dag(dag=dag)
+    if AIRFLOW_V_3_0_PLUS:
+        dag_bag.bag_dag(dag=dag)
+    else:
+        dag_bag.bag_dag(dag=dag, root_dag=dag)
 
     return dag_bag
 
@@ -1600,10 +1875,13 @@ def test_clear_overlapping_external_task_marker(dag_bag_head_tail, session):
             dag_id=dag.dag_id,
             start_date=logical_date,
             state=DagRunState.SUCCESS,
-            logical_date=logical_date,
             run_type=DagRunType.MANUAL,
             run_id=f"test_{delta}",
         )
+        if AIRFLOW_V_3_0_PLUS:
+            dagrun.logical_date = logical_date
+        else:
+            dagrun.execution_date = logical_date
         session.add(dagrun)
         for task in dag.tasks:
             ti = TaskInstance(task=task)
@@ -1625,10 +1903,13 @@ def test_clear_overlapping_external_task_marker_with_end_date(dag_bag_head_tail,
             dag_id=dag.dag_id,
             start_date=logical_date,
             state=DagRunState.SUCCESS,
-            logical_date=logical_date,
             run_type=DagRunType.MANUAL,
             run_id=f"test_{delta}",
         )
+        if AIRFLOW_V_3_0_PLUS:
+            dagrun.logical_date = logical_date
+        else:
+            dagrun.execution_date = logical_date
         session.add(dagrun)
         for task in dag.tasks:
             ti = TaskInstance(task=task)
@@ -1689,7 +1970,10 @@ def dag_bag_head_tail_mapped_tasks():
         )
         head >> body >> tail
 
-    dag_bag.bag_dag(dag=dag)
+    if AIRFLOW_V_3_0_PLUS:
+        dag_bag.bag_dag(dag=dag)
+    else:
+        dag_bag.bag_dag(dag=dag, root_dag=dag)
 
     return dag_bag
 
@@ -1705,10 +1989,13 @@ def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_m
             dag_id=dag.dag_id,
             start_date=logical_date,
             state=DagRunState.SUCCESS,
-            logical_date=logical_date,
             run_type=DagRunType.MANUAL,
             run_id=f"test_{delta}",
         )
+        if AIRFLOW_V_3_0_PLUS:
+            dagrun.logical_date = logical_date
+        else:
+            dagrun.execution_date = logical_date
         session.add(dagrun)
         for task in dag.tasks:
             if task.task_id == "dummy_task":
@@ -1721,12 +2008,19 @@ def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_m
                 ti.state = TaskInstanceState.SUCCESS
                 dagrun.task_instances.append(ti)
     session.flush()
+    if AIRFLOW_V_3_0_PLUS:
+        dag = dag.partial_subset(
+            task_ids=["head"],
+            include_downstream=True,
+            include_upstream=False,
+        )
+    else:
+        dag = dag.partial_subset(
+            task_ids_or_regex=["head"],
+            include_downstream=True,
+            include_upstream=False,
+        )
 
-    dag = dag.partial_subset(
-        task_ids=["head"],
-        include_downstream=True,
-        include_upstream=False,
-    )
     task_ids = list(dag.task_dict)
     assert (
         dag.clear(
diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py
index cfffd2b3823..a2fff3a335b 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -59,10 +59,12 @@ from airflow.sdk.api.datamodels._generated import (
 )
 from airflow.sdk.exceptions import ErrorType
 from airflow.sdk.execution_time.comms import (
+    DRCount,
     ErrorResponse,
     OKResponse,
     SkipDownstreamTasks,
     TaskRescheduleStartDate,
+    TICount,
 )
 from airflow.utils.net import get_hostname
 from airflow.utils.platform import getuser
@@ -200,6 +202,31 @@ class TaskInstanceOperations:
         resp = self.client.get(f"task-reschedules/{id}/start_date", params={"try_number": try_number})
         return TaskRescheduleStartDate.model_construct(start_date=resp.json())
 
+    def get_count(
+        self,
+        dag_id: str,
+        task_ids: list[str] | None = None,
+        task_group_id: str | None = None,
+        logical_dates: list[datetime] | None = None,
+        run_ids: list[str] | None = None,
+        states: list[str] | None = None,
+    ) -> TICount:
+        """Get count of task instances matching the given criteria."""
+        params = {
+            "dag_id": dag_id,
+            "task_ids": task_ids,
+            "task_group_id": task_group_id,
+            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
+            "run_ids": run_ids,
+            "states": states,
+        }
+
+        # Remove None values from params
+        params = {k: v for k, v in params.items() if v is not None}
+
+        resp = self.client.get("task-instances/count", params=params)
+        return TICount(count=resp.json())
+
 
 class ConnectionOperations:
     __slots__ = ("client",)
@@ -452,6 +479,27 @@ class DagRunOperations:
         resp = self.client.get(f"dag-runs/{dag_id}/{run_id}/state")
         return DagRunStateResponse.model_validate_json(resp.read())
 
+    def get_count(
+        self,
+        dag_id: str,
+        logical_dates: list[datetime] | None = None,
+        run_ids: list[str] | None = None,
+        states: list[str] | None = None,
+    ) -> DRCount:
+        """Get count of DAG runs matching the given criteria."""
+        params = {
+            "dag_id": dag_id,
+            "logical_dates": [d.isoformat() for d in logical_dates] if logical_dates is not None else None,
+            "run_ids": run_ids,
+            "states": states,
+        }
+
+        # Remove None values from params
+        params = {k: v for k, v in params.items() if v is not None}
+
+        resp = self.client.get("dag-runs/count", params=params)
+        return DRCount(count=resp.json())
+
 
 class BearerAuth(httpx.Auth):
     def __init__(self, token: str):
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index d579aa8fb48..76c858f3ac1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -222,6 +222,20 @@ class TaskRescheduleStartDate(BaseModel):
     type: Literal["TaskRescheduleStartDate"] = "TaskRescheduleStartDate"
 
 
+class TICount(BaseModel):
+    """Response containing count of Task Instances matching certain filters."""
+
+    count: int
+    type: Literal["TICount"] = "TICount"
+
+
+class DRCount(BaseModel):
+    """Response containing count of DAG Runs matching certain filters."""
+
+    count: int
+    type: Literal["DRCount"] = "DRCount"
+
+
 class ErrorResponse(BaseModel):
     error: ErrorType = ErrorType.GENERIC_ERROR
     detail: dict | None = None
@@ -239,10 +253,12 @@ ToTask = Annotated[
         AssetEventsResult,
         ConnectionResult,
         DagRunStateResult,
+        DRCount,
         ErrorResponse,
         PrevSuccessfulDagRunResult,
         StartupDetails,
         TaskRescheduleStartDate,
+        TICount,
         VariableResult,
         XComResult,
         XComCountResponse,
@@ -445,30 +461,50 @@ class GetTaskRescheduleStartDate(BaseModel):
     type: Literal["GetTaskRescheduleStartDate"] = "GetTaskRescheduleStartDate"
 
 
+class GetTICount(BaseModel):
+    dag_id: str
+    task_ids: list[str] | None = None
+    task_group_id: str | None = None
+    logical_dates: list[AwareDatetime] | None = None
+    run_ids: list[str] | None = None
+    states: list[str] | None = None
+    type: Literal["GetTICount"] = "GetTICount"
+
+
+class GetDRCount(BaseModel):
+    dag_id: str
+    logical_dates: list[AwareDatetime] | None = None
+    run_ids: list[str] | None = None
+    states: list[str] | None = None
+    type: Literal["GetDRCount"] = "GetDRCount"
+
+
 ToSupervisor = Annotated[
     Union[
-        SucceedTask,
         DeferTask,
+        DeleteXCom,
         GetAssetByName,
         GetAssetByUri,
         GetAssetEventByAsset,
         GetAssetEventByAssetAlias,
         GetConnection,
         GetDagRunState,
+        GetDRCount,
         GetPrevSuccessfulDagRun,
         GetTaskRescheduleStartDate,
+        GetTICount,
         GetVariable,
         GetXCom,
         GetXComCount,
         PutVariable,
         RescheduleTask,
         RetryTask,
-        SkipDownstreamTasks,
         SetRenderedFields,
         SetXCom,
+        SkipDownstreamTasks,
+        SucceedTask,
         TaskState,
         TriggerDagRun,
-        DeleteXCom,
     ],
     Field(discriminator="type"),
 ]
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index a353ffce23c..21ae3cd4cae 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -77,8 +77,10 @@ from airflow.sdk.execution_time.comms import (
     GetAssetEventByAssetAlias,
     GetConnection,
     GetDagRunState,
+    GetDRCount,
     GetPrevSuccessfulDagRun,
     GetTaskRescheduleStartDate,
+    GetTICount,
     GetVariable,
     GetXCom,
     GetXComCount,
@@ -988,6 +990,24 @@ class ActivitySubprocess(WatchedSubprocess):
         elif isinstance(msg, GetTaskRescheduleStartDate):
             tr_resp = self.client.task_instances.get_reschedule_start_date(msg.ti_id, msg.try_number)
             resp = tr_resp.model_dump_json().encode()
+        elif isinstance(msg, GetTICount):
+            ti_count = self.client.task_instances.get_count(
+                dag_id=msg.dag_id,
+                task_ids=msg.task_ids,
+                task_group_id=msg.task_group_id,
+                logical_dates=msg.logical_dates,
+                run_ids=msg.run_ids,
+                states=msg.states,
+            )
+            resp = ti_count.model_dump_json().encode()
+        elif isinstance(msg, GetDRCount):
+            dr_count = self.client.dag_runs.get_count(
+                dag_id=msg.dag_id,
+                logical_dates=msg.logical_dates,
+                run_ids=msg.run_ids,
+                states=msg.states,
+            )
+            resp = dr_count.model_dump_json().encode()
         else:
             log.error("Unhandled request", msg=msg)
             return
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index d5d738f5254..d89fbeb0aa9 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -58,9 +58,12 @@ from airflow.sdk.execution_time.callback_runner import create_executable_runner
 from airflow.sdk.execution_time.comms import (
     DagRunStateResult,
     DeferTask,
+    DRCount,
     ErrorResponse,
     GetDagRunState,
+    GetDRCount,
     GetTaskRescheduleStartDate,
+    GetTICount,
     RescheduleTask,
     RetryTask,
     SetRenderedFields,
@@ -69,6 +72,7 @@ from airflow.sdk.execution_time.comms import (
     SucceedTask,
     TaskRescheduleStartDate,
     TaskState,
+    TICount,
     ToSupervisor,
     ToTask,
     TriggerDagRun,
@@ -400,6 +404,62 @@ class RuntimeTaskInstance(TaskInstance):
 
         return response.start_date
 
+    @staticmethod
+    def get_ti_count(
+        dag_id: str,
+        task_ids: list[str] | None = None,
+        task_group_id: str | None = None,
+        logical_dates: list[datetime] | None = None,
+        run_ids: list[str] | None = None,
+        states: list[str] | None = None,
+    ) -> int:
+        """Return the number of task instances matching the given criteria."""
+        log = structlog.get_logger(logger_name="task")
+
+        SUPERVISOR_COMMS.send_request(
+            log=log,
+            msg=GetTICount(
+                dag_id=dag_id,
+                task_ids=task_ids,
+                task_group_id=task_group_id,
+                logical_dates=logical_dates,
+                run_ids=run_ids,
+                states=states,
+            ),
+        )
+        response = SUPERVISOR_COMMS.get_message()
+
+        if TYPE_CHECKING:
+            assert isinstance(response, TICount)
+
+        return response.count
+
+    @staticmethod
+    def get_dr_count(
+        dag_id: str,
+        logical_dates: list[datetime] | None = None,
+        run_ids: list[str] | None = None,
+        states: list[str] | None = None,
+    ) -> int:
+        """Return the number of DAG runs matching the given criteria."""
+        log = structlog.get_logger(logger_name="task")
+
+        SUPERVISOR_COMMS.send_request(
+            log=log,
+            msg=GetDRCount(
+                dag_id=dag_id,
+                logical_dates=logical_dates,
+                run_ids=run_ids,
+                states=states,
+            ),
+        )
+        response = SUPERVISOR_COMMS.get_message()
+
+        if TYPE_CHECKING:
+            assert isinstance(response, DRCount)
+
+        return response.count
+
 
 def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None:
     """Push a XCom through XCom.set, which pushes to XCom Backend if configured."""
diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py
index 6760ea51959..d9544589cf9 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -87,6 +87,24 @@ class RuntimeTaskInstanceProtocol(Protocol):
 
     def get_first_reschedule_date(self, first_try_number) -> AwareDatetime | None: ...
 
+    @staticmethod
+    def get_ti_count(
+        dag_id: str,
+        task_ids: list[str] | None = None,
+        task_group_id: str | None = None,
+        logical_dates: list[AwareDatetime] | None = None,
+        run_ids: list[str] | None = None,
+        states: list[str] | None = None,
+    ) -> int: ...
+
+    @staticmethod
+    def get_dr_count(
+        dag_id: str,
+        logical_dates: list[AwareDatetime] | None = None,
+        run_ids: list[str] | None = None,
+        states: list[str] | None = None,
+    ) -> int: ...
+
 
 class OutletEventAccessorProtocol(Protocol, attrs.AttrsInstance):
     """Protocol for managing access to a specific outlet event accessor."""
diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py
index f0092f8aea9..3339a6cee4c 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -405,6 +405,48 @@ class TestTaskInstanceOperations:
 
         assert result == {"ok": True}
 
+    def test_get_count_basic(self):
+        """Test basic get_count functionality with just dag_id."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == "/task-instances/count"
+            assert request.url.params.get("dag_id") == "test_dag"
+            return httpx.Response(200, json=5)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.task_instances.get_count(dag_id="test_dag")
+        assert result.count == 5
+
+    def test_get_count_with_all_params(self):
+        """Test get_count with all optional parameters."""
+
+        logical_dates_str = ["2024-01-01T00:00:00+00:00", "2024-01-02T00:00:00+00:00"]
+        logical_dates = [timezone.parse(d) for d in logical_dates_str]
+        task_ids = ["task1", "task2"]
+        states = ["success", "failed"]
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            assert request.url.path == "/task-instances/count"
+            assert request.method == "GET"
+            params = request.url.params
+            assert params["dag_id"] == "test_dag"
+            assert params.get_list("task_ids") == task_ids
+            assert params["task_group_id"] == "group1"
+            assert params.get_list("logical_dates") == logical_dates_str
+            assert params.get_list("run_ids") == []
+            assert params.get_list("states") == states
+            return httpx.Response(200, json=10)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.task_instances.get_count(
+            dag_id="test_dag",
+            task_ids=task_ids,
+            task_group_id="group1",
+            logical_dates=logical_dates,
+            states=states,
+        )
+        assert result.count == 10
+
 
 class TestVariableOperations:
     """
@@ -904,6 +946,58 @@ class TestDagRunOperations:
 
         assert result == DagRunStateResponse(state=DagRunState.RUNNING)
 
+    def test_get_count_basic(self):
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == "/dag-runs/count":
+                assert request.url.params["dag_id"] == "test_dag"
+                return httpx.Response(status_code=200, json=1)
+            return httpx.Response(status_code=422)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.dag_runs.get_count(dag_id="test_dag")
+        assert result.count == 1
+
+    def test_get_count_with_states(self):
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == "/dag-runs/count":
+                assert request.url.params["dag_id"] == "test_dag"
+                assert request.url.params.get_list("states") == ["success", "failed"]
+                return httpx.Response(status_code=200, json=2)
+            return httpx.Response(status_code=422)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.dag_runs.get_count(dag_id="test_dag", states=["success", "failed"])
+        assert result.count == 2
+
+    def test_get_count_with_logical_dates(self):
+        logical_dates = [timezone.datetime(2025, 1, 1), timezone.datetime(2025, 1, 2)]
+        logical_dates_str = [d.isoformat() for d in logical_dates]
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == "/dag-runs/count":
+                assert request.url.params["dag_id"] == "test_dag"
+                assert request.url.params.get_list("logical_dates") == logical_dates_str
+                return httpx.Response(status_code=200, json=2)
+            return httpx.Response(status_code=422)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.dag_runs.get_count(
+            dag_id="test_dag", logical_dates=[timezone.datetime(2025, 1, 1), timezone.datetime(2025, 1, 2)]
+        )
+        assert result.count == 2
+
+    def test_get_count_with_run_ids(self):
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == "/dag-runs/count":
+                assert request.url.params["dag_id"] == "test_dag"
+                assert request.url.params.get_list("run_ids") == ["run1", "run2"]
+                return httpx.Response(status_code=200, json=2)
+            return httpx.Response(status_code=422)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.dag_runs.get_count(dag_id="test_dag", run_ids=["run1", "run2"])
+        assert result.count == 2
+
 
 class TestTaskRescheduleOperations:
     def test_get_start_date(self):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 64906726219..5cae5232e64 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -60,6 +60,7 @@ from airflow.sdk.execution_time.comms import (
     DagRunStateResult,
     DeferTask,
     DeleteXCom,
+    DRCount,
     ErrorResponse,
     GetAssetByName,
     GetAssetByUri,
@@ -67,8 +68,10 @@ from airflow.sdk.execution_time.comms import (
     GetAssetEventByAssetAlias,
     GetConnection,
     GetDagRunState,
+    GetDRCount,
     GetPrevSuccessfulDagRun,
     GetTaskRescheduleStartDate,
+    GetTICount,
     GetVariable,
     GetXCom,
     OKResponse,
@@ -80,6 +83,7 @@ from airflow.sdk.execution_time.comms import (
     SucceedTask,
     TaskRescheduleStartDate,
     TaskState,
+    TICount,
     TriggerDagRun,
     VariableResult,
     XComResult,
@@ -1350,6 +1354,36 @@ class TestHandleRequest:
                 TaskRescheduleStartDate(start_date=timezone.parse("2024-10-31T12:00:00Z")),
                 id="get_task_reschedule_start_date",
             ),
+            pytest.param(
+                GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]),
+                b'{"count":2,"type":"TICount"}\n',
+                "task_instances.get_count",
+                (),
+                {
+                    "dag_id": "test_dag",
+                    "logical_dates": None,
+                    "run_ids": None,
+                    "states": None,
+                    "task_group_id": None,
+                    "task_ids": ["task1", "task2"],
+                },
+                TICount(count=2),
+                id="get_ti_count",
+            ),
+            pytest.param(
+                GetDRCount(dag_id="test_dag", states=["success", "failed"]),
+                b'{"count":2,"type":"DRCount"}\n',
+                "dag_runs.get_count",
+                (),
+                {
+                    "dag_id": "test_dag",
+                    "logical_dates": None,
+                    "run_ids": None,
+                    "states": ["success", "failed"],
+                },
+                DRCount(count=2),
+                id="get_dr_count",
+            ),
         ],
     )
     def test_handle_requests(
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 150b691fc17..f7ac25f6ae8 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -67,9 +67,12 @@ from airflow.sdk.execution_time.comms import (
     ConnectionResult,
     DagRunStateResult,
     DeferTask,
+    DRCount,
     ErrorResponse,
     GetConnection,
     GetDagRunState,
+    GetDRCount,
+    GetTICount,
     GetVariable,
     GetXCom,
     OKResponse,
@@ -81,6 +84,7 @@ from airflow.sdk.execution_time.comms import (
     SucceedTask,
     TaskRescheduleStartDate,
     TaskState,
+    TICount,
     TriggerDagRun,
     VariableResult,
     XComResult,
@@ -1396,6 +1400,54 @@ class TestRuntimeTaskInstance:
         context = runtime_ti.get_template_context()
         assert runtime_ti.get_first_reschedule_date(context=context) == expected_date
 
+    def test_get_ti_count(self, mock_supervisor_comms):
+        """Test that get_ti_count sends the correct request and returns the count."""
+        mock_supervisor_comms.get_message.return_value = TICount(count=2)
+
+        count = RuntimeTaskInstance.get_ti_count(
+            dag_id="test_dag",
+            task_ids=["task1", "task2"],
+            task_group_id="group1",
+            logical_dates=[timezone.datetime(2024, 1, 1)],
+            run_ids=["run1"],
+            states=["success", "failed"],
+        )
+
+        mock_supervisor_comms.send_request.assert_called_once_with(
+            log=mock.ANY,
+            msg=GetTICount(
+                dag_id="test_dag",
+                task_ids=["task1", "task2"],
+                task_group_id="group1",
+                logical_dates=[timezone.datetime(2024, 1, 1)],
+                run_ids=["run1"],
+                states=["success", "failed"],
+            ),
+        )
+        assert count == 2
+
+    def test_get_dr_count(self, mock_supervisor_comms):
+        """Test that get_dr_count sends the correct request and returns the count."""
+        mock_supervisor_comms.get_message.return_value = DRCount(count=2)
+
+        count = RuntimeTaskInstance.get_dr_count(
+            dag_id="test_dag",
+            logical_dates=[timezone.datetime(2024, 1, 1)],
+            run_ids=["run1"],
+            states=["success", "failed"],
+        )
+
+        mock_supervisor_comms.send_request.assert_called_once_with(
+            log=mock.ANY,
+            msg=GetDRCount(
+                dag_id="test_dag",
+                logical_dates=[timezone.datetime(2024, 1, 1)],
+                run_ids=["run1"],
+                states=["success", "failed"],
+            ),
+        )
+        assert count == 2
+
 
 class TestXComAfterTaskExecution:
     @pytest.mark.parametrize(

Reply via email to