This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ddec35dc6b Removed deprecated TaskStateTrigger from airflow.triggers.external_task module (#41737)
ddec35dc6b is described below

commit ddec35dc6b9aa8ef1b93a1aa0e62861fc2d599c5
Author: Jens Scheffler <[email protected]>
AuthorDate: Tue Aug 27 15:55:00 2024 +0200

    Removed deprecated TaskStateTrigger from airflow.triggers.external_task module (#41737)
---
 airflow/triggers/external_task.py    | 121 +--------------------
 newsfragments/41737.significant.rst  |   1 +
 tests/triggers/test_external_task.py | 201 +----------------------------------
 3 files changed, 4 insertions(+), 319 deletions(-)

diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py
index a5de817c35..cd43d59876 100644
--- a/airflow/triggers/external_task.py
+++ b/airflow/triggers/external_task.py
@@ -21,16 +21,12 @@ import typing
 from typing import Any
 
 from asgiref.sync import sync_to_async
-from deprecated import deprecated
 from sqlalchemy import func
 
-from airflow.exceptions import RemovedInAirflow3Warning
-from airflow.models import DagRun, TaskInstance
+from airflow.models import DagRun
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils.sensor_helper import _get_count
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import TaskInstanceState
-from airflow.utils.timezone import utcnow
 
 if typing.TYPE_CHECKING:
     from datetime import datetime
@@ -136,121 +132,6 @@ class WorkflowTrigger(BaseTrigger):
         )
 
 
-@deprecated(
-    reason="TaskStateTrigger has been deprecated and will be removed in future.",
-    category=RemovedInAirflow3Warning,
-)
-class TaskStateTrigger(BaseTrigger):
-    """
-    Waits asynchronously for a task in a different DAG to complete for a specific logical date.
-
-    :param dag_id: The dag_id that contains the task you want to wait for
-    :param task_id: The task_id that contains the task you want to
-        wait for.
-    :param states: allowed states, default is ``['success']``
-    :param execution_dates: task execution time interval
-    :param poll_interval: The time interval in seconds to check the state.
-        The default value is 5 sec.
-    :param trigger_start_time: time in Datetime format when the trigger was started. Is used
-        to control the execution of trigger to prevent infinite loop in case if specified name
-        of the dag does not exist in database. It will wait period of time equals _timeout_sec parameter
-        from the time, when the trigger was started and if the execution lasts more time than expected,
-        the trigger will terminate with 'timeout' status.
-    """
-
-    def __init__(
-        self,
-        dag_id: str,
-        execution_dates: list[datetime],
-        trigger_start_time: datetime,
-        states: list[str] | None = None,
-        task_id: str | None = None,
-        poll_interval: float = 2.0,
-    ):
-        super().__init__()
-        self.dag_id = dag_id
-        self.task_id = task_id
-        self.states = states
-        self.execution_dates = execution_dates
-        self.poll_interval = poll_interval
-        self.trigger_start_time = trigger_start_time
-        self.states = states or [TaskInstanceState.SUCCESS.value]
-        self._timeout_sec = 60
-
-    def serialize(self) -> tuple[str, dict[str, typing.Any]]:
-        """Serialize TaskStateTrigger arguments and classpath."""
-        return (
-            "airflow.triggers.external_task.TaskStateTrigger",
-            {
-                "dag_id": self.dag_id,
-                "task_id": self.task_id,
-                "states": self.states,
-                "execution_dates": self.execution_dates,
-                "poll_interval": self.poll_interval,
-                "trigger_start_time": self.trigger_start_time,
-            },
-        )
-
-    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
-        """
-        Check periodically in the database to see if the dag exists and is in the running state.
-
-        If found, wait until the task specified will reach one of the expected states.
-        If dag with specified name was not in the running state after _timeout_sec seconds
-        after starting execution process of the trigger, terminate with status 'timeout'.
-        """
-        try:
-            while True:
-                delta = utcnow() - self.trigger_start_time
-                if delta.total_seconds() < self._timeout_sec:
-                    # mypy confuses typing here
-                    if await self.count_running_dags() == 0:  # type: ignore[call-arg]
-                        self.log.info("Waiting for DAG to start execution...")
-                        await asyncio.sleep(self.poll_interval)
-                else:
-                    yield TriggerEvent({"status": "timeout"})
-                    return
-                # mypy confuses typing here
-                if await self.count_tasks() == len(self.execution_dates):  # type: ignore[call-arg]
-                    yield TriggerEvent({"status": "success"})
-                    return
-                self.log.info("Task is still running, sleeping for %s seconds...", self.poll_interval)
-                await asyncio.sleep(self.poll_interval)
-        except Exception:
-            yield TriggerEvent({"status": "failed"})
-
-    @sync_to_async
-    @provide_session
-    def count_running_dags(self, session: Session):
-        """Count how many dag instances in running state in the database."""
-        dags = (
-            session.query(func.count("*"))
-            .filter(
-                TaskInstance.dag_id == self.dag_id,
-                TaskInstance.execution_date.in_(self.execution_dates),
-                TaskInstance.state.in_(["running", "success"]),
-            )
-            .scalar()
-        )
-        return dags
-
-    @sync_to_async
-    @provide_session
-    def count_tasks(self, *, session: Session = NEW_SESSION) -> int | None:
-        """Count how many task instances in the database match our criteria."""
-        count = (
-            session.query(func.count("*"))  # .count() is inefficient
-            .filter(
-                TaskInstance.dag_id == self.dag_id,
-                TaskInstance.task_id == self.task_id,
-                TaskInstance.state.in_(self.states),
-                TaskInstance.execution_date.in_(self.execution_dates),
-            )
-            .scalar()
-        )
-        return typing.cast(int, count)
-
-
 class DagStateTrigger(BaseTrigger):
     """
     Waits asynchronously for a DAG to complete for a specific logical date.
diff --git a/newsfragments/41737.significant.rst b/newsfragments/41737.significant.rst
new file mode 100644
index 0000000000..55704581be
--- /dev/null
+++ b/newsfragments/41737.significant.rst
@@ -0,0 +1 @@
+Removed deprecated ``TaskStateTrigger`` from ``airflow.triggers.external_task`` module.
diff --git a/tests/triggers/test_external_task.py b/tests/triggers/test_external_task.py
index 7bb41c3450..ced867c4bd 100644
--- a/tests/triggers/test_external_task.py
+++ b/tests/triggers/test_external_task.py
@@ -17,23 +17,17 @@
 from __future__ import annotations
 
 import asyncio
-import datetime
 import time
 from unittest import mock
 
 import pytest
-from sqlalchemy.exc import SQLAlchemyError
 
-from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
-from airflow.models.taskinstance import TaskInstance
-from airflow.operators.empty import EmptyOperator
 from airflow.triggers.base import TriggerEvent
-from airflow.triggers.external_task import DagStateTrigger, TaskStateTrigger, WorkflowTrigger
+from airflow.triggers.external_task import DagStateTrigger, WorkflowTrigger
 from airflow.utils import timezone
-from airflow.utils.state import DagRunState, TaskInstanceState
-from airflow.utils.timezone import utcnow
+from airflow.utils.state import DagRunState
 
 
 class TestWorkflowTrigger:
@@ -222,197 +216,6 @@ class TestWorkflowTrigger:
         }
 
 
-class TestTaskStateTrigger:
-    DAG_ID = "external_task"
-    TASK_ID = "external_task_op"
-    RUN_ID = "external_task_run_id"
-    STATES = ["success", "fail"]
-
-    @pytest.mark.skip_if_database_isolation_mode  # Test is broken in db isolation mode
-    @pytest.mark.db_test
-    @pytest.mark.asyncio
-    async def test_task_state_trigger_success(self, session):
-        """
-        Asserts that the TaskStateTrigger only goes off on or after a TaskInstance
-        reaches an allowed state (i.e. SUCCESS).
-        """
-        trigger_start_time = utcnow()
-        dag = DAG(self.DAG_ID, schedule=None, start_date=timezone.datetime(2022, 1, 1))
-        dag_run = DagRun(
-            dag_id=dag.dag_id,
-            run_type="manual",
-            execution_date=timezone.datetime(2022, 1, 1),
-            run_id=self.RUN_ID,
-        )
-        session.add(dag_run)
-        session.commit()
-
-        external_task = EmptyOperator(task_id=self.TASK_ID, dag=dag)
-        instance = TaskInstance(external_task, run_id=self.RUN_ID)
-        session.add(instance)
-        session.commit()
-
-        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
-            trigger = TaskStateTrigger(
-                dag_id=dag.dag_id,
-                task_id=instance.task_id,
-                states=self.STATES,
-                execution_dates=[timezone.datetime(2022, 1, 1)],
-                poll_interval=0.2,
-                trigger_start_time=trigger_start_time,
-            )
-
-        task = asyncio.create_task(trigger.run().__anext__())
-        await asyncio.sleep(0.5)
-
-        # It should not have produced a result
-        assert task.done() is False
-
-        # Progress the task to a "success" state so that run() yields a TriggerEvent
-        instance.state = TaskInstanceState.SUCCESS
-        session.commit()
-        await asyncio.sleep(0.5)
-        assert task.done() is True
-
-        # Prevents error when task is destroyed while in "pending" state
-        asyncio.get_event_loop().stop()
-
-    @mock.patch("airflow.triggers.external_task.utcnow")
-    @pytest.mark.asyncio
-    async def test_task_state_trigger_timeout(self, mock_utcnow):
-        trigger_start_time = utcnow()
-        mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=61)
-
-        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
-            trigger = TaskStateTrigger(
-                dag_id="dag1",
-                task_id="task1",
-                states=self.STATES,
-                execution_dates=[timezone.datetime(2022, 1, 1)],
-                poll_interval=0.2,
-                trigger_start_time=trigger_start_time,
-            )
-
-        trigger.count_running_dags = mock.AsyncMock()
-        trigger.count_running_dags.return_value = 0
-
-        gen = trigger.run()
-        task = asyncio.create_task(gen.__anext__())
-        await task
-
-        result = task.result()
-        assert isinstance(result, TriggerEvent)
-        assert result.payload == {"status": "timeout"}
-        assert task.done() is True
-
-        # test that it returns after yielding
-        with pytest.raises(StopAsyncIteration):
-            await gen.__anext__()
-
-    @mock.patch("airflow.triggers.external_task.utcnow")
-    @mock.patch("airflow.triggers.external_task.asyncio.sleep")
-    @pytest.mark.asyncio
-    async def test_task_state_trigger_timeout_sleep_success(self, mock_sleep, mock_utcnow):
-        trigger_start_time = utcnow()
-        mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=20)
-
-        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
-            trigger = TaskStateTrigger(
-                dag_id="dag1",
-                task_id="task1",
-                states=self.STATES,
-                execution_dates=[timezone.datetime(2022, 1, 1)],
-                poll_interval=0.2,
-                trigger_start_time=trigger_start_time,
-            )
-
-        trigger.count_running_dags = mock.AsyncMock()
-        trigger.count_running_dags.return_value = 0
-
-        trigger.count_tasks = mock.AsyncMock()
-        trigger.count_tasks.return_value = 1
-
-        gen = trigger.run()
-        task = asyncio.create_task(gen.__anext__())
-        await task
-
-        mock_sleep.assert_awaited()
-        assert mock_sleep.await_count == 1
-
-        result = task.result()
-        assert isinstance(result, TriggerEvent)
-        assert result.payload == {"status": "success"}
-        assert task.done() is True
-
-        # test that it returns after yielding
-        with pytest.raises(StopAsyncIteration):
-            await gen.__anext__()
-
-    @mock.patch("airflow.triggers.external_task.utcnow")
-    @mock.patch("airflow.triggers.external_task.asyncio.sleep")
-    @pytest.mark.asyncio
-    async def test_task_state_trigger_failed_exception(self, mock_sleep, mock_utcnow):
-        """
-        Asserts that the TaskStateTrigger only goes off on or after a TaskInstance
-        reaches an allowed state (i.e. SUCCESS).
-        """
-        trigger_start_time = utcnow()
-        mock_utcnow.return_value = +datetime.timedelta(seconds=61)
-
-        mock_utcnow.side_effect = [
-            trigger_start_time,
-            trigger_start_time + datetime.timedelta(seconds=20),
-        ]
-
-        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
-            trigger = TaskStateTrigger(
-                dag_id="dag1",
-                task_id="task1",
-                states=self.STATES,
-                execution_dates=[timezone.datetime(2022, 1, 1)],
-                poll_interval=0.2,
-                trigger_start_time=trigger_start_time,
-            )
-
-        trigger.count_running_dags = mock.AsyncMock()
-        trigger.count_running_dags.side_effect = [SQLAlchemyError]
-
-        gen = trigger.run()
-        task = asyncio.create_task(gen.__anext__())
-        await task
-
-        result = task.result()
-        assert isinstance(result, TriggerEvent)
-        assert result.payload == {"status": "failed"}
-        assert task.done() is True
-
-    def test_serialization(self):
-        """
-        Asserts that the TaskStateTrigger correctly serializes its arguments
-        and classpath.
-        """
-        trigger_start_time = utcnow()
-        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
-            trigger = TaskStateTrigger(
-                dag_id=self.DAG_ID,
-                task_id=self.TASK_ID,
-                states=self.STATES,
-                execution_dates=[timezone.datetime(2022, 1, 1)],
-                poll_interval=5,
-                trigger_start_time=trigger_start_time,
-            )
-        classpath, kwargs = trigger.serialize()
-        assert classpath == "airflow.triggers.external_task.TaskStateTrigger"
-        assert kwargs == {
-            "dag_id": self.DAG_ID,
-            "task_id": self.TASK_ID,
-            "states": self.STATES,
-            "execution_dates": [timezone.datetime(2022, 1, 1)],
-            "poll_interval": 5,
-            "trigger_start_time": trigger_start_time,
-        }
-
-
 class TestDagStateTrigger:
     DAG_ID = "test_dag_state_trigger"
     RUN_ID = "external_task_run_id"

Reply via email to