This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ddec35dc6b Removed deprecated TaskStateTrigger from airflow.triggers.external_task module (#41737)
ddec35dc6b is described below
commit ddec35dc6b9aa8ef1b93a1aa0e62861fc2d599c5
Author: Jens Scheffler <[email protected]>
AuthorDate: Tue Aug 27 15:55:00 2024 +0200
Removed deprecated TaskStateTrigger from airflow.triggers.external_task module (#41737)
---
airflow/triggers/external_task.py | 121 +--------------------
newsfragments/41737.significant.rst | 1 +
tests/triggers/test_external_task.py | 201 +----------------------------------
3 files changed, 4 insertions(+), 319 deletions(-)
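Downstream code that still imports the removed class will now hit an ImportError. Below is a minimal sketch of a guarded import; only the module path and class names are taken from the diff that follows, and the fallback handling is purely illustrative, not part of this commit:

    # Guarded import for code written against the pre-#41737 module layout.
    # TaskStateTrigger is removed by this commit; DagStateTrigger and
    # WorkflowTrigger remain in airflow.triggers.external_task.
    try:
        from airflow.triggers.external_task import TaskStateTrigger
    except ImportError:
        TaskStateTrigger = None  # removed from Airflow main by this commit

    from airflow.triggers.external_task import DagStateTrigger, WorkflowTrigger

    if TaskStateTrigger is None:
        print("TaskStateTrigger is gone; switch to WorkflowTrigger or DagStateTrigger.")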
diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py
index a5de817c35..cd43d59876 100644
--- a/airflow/triggers/external_task.py
+++ b/airflow/triggers/external_task.py
@@ -21,16 +21,12 @@ import typing
from typing import Any
from asgiref.sync import sync_to_async
-from deprecated import deprecated
from sqlalchemy import func
-from airflow.exceptions import RemovedInAirflow3Warning
-from airflow.models import DagRun, TaskInstance
+from airflow.models import DagRun
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.sensor_helper import _get_count
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import TaskInstanceState
-from airflow.utils.timezone import utcnow
if typing.TYPE_CHECKING:
from datetime import datetime
@@ -136,121 +132,6 @@ class WorkflowTrigger(BaseTrigger):
)
-@deprecated(
- reason="TaskStateTrigger has been deprecated and will be removed in future.",
- category=RemovedInAirflow3Warning,
-)
-class TaskStateTrigger(BaseTrigger):
- """
- Waits asynchronously for a task in a different DAG to complete for a specific logical date.
-
- :param dag_id: The dag_id that contains the task you want to wait for
- :param task_id: The task_id that contains the task you want to
- wait for.
- :param states: allowed states, default is ``['success']``
- :param execution_dates: task execution time interval
- :param poll_interval: The time interval in seconds to check the state.
- The default value is 5 sec.
- :param trigger_start_time: time in Datetime format when the trigger was started. Is used
- to control the execution of trigger to prevent infinite loop in case if specified name
- of the dag does not exist in database. It will wait period of time equals _timeout_sec parameter
- from the time, when the trigger was started and if the execution lasts more time than expected,
- the trigger will terminate with 'timeout' status.
- """
-
- def __init__(
- self,
- dag_id: str,
- execution_dates: list[datetime],
- trigger_start_time: datetime,
- states: list[str] | None = None,
- task_id: str | None = None,
- poll_interval: float = 2.0,
- ):
- super().__init__()
- self.dag_id = dag_id
- self.task_id = task_id
- self.states = states
- self.execution_dates = execution_dates
- self.poll_interval = poll_interval
- self.trigger_start_time = trigger_start_time
- self.states = states or [TaskInstanceState.SUCCESS.value]
- self._timeout_sec = 60
-
- def serialize(self) -> tuple[str, dict[str, typing.Any]]:
- """Serialize TaskStateTrigger arguments and classpath."""
- return (
- "airflow.triggers.external_task.TaskStateTrigger",
- {
- "dag_id": self.dag_id,
- "task_id": self.task_id,
- "states": self.states,
- "execution_dates": self.execution_dates,
- "poll_interval": self.poll_interval,
- "trigger_start_time": self.trigger_start_time,
- },
- )
-
- async def run(self) -> typing.AsyncIterator[TriggerEvent]:
- """
- Check periodically in the database to see if the dag exists and is in the running state.
-
- If found, wait until the task specified will reach one of the expected states.
- If dag with specified name was not in the running state after _timeout_sec seconds
- after starting execution process of the trigger, terminate with status 'timeout'.
- """
- try:
- while True:
- delta = utcnow() - self.trigger_start_time
- if delta.total_seconds() < self._timeout_sec:
- # mypy confuses typing here
- if await self.count_running_dags() == 0: # type: ignore[call-arg]
- self.log.info("Waiting for DAG to start execution...")
- await asyncio.sleep(self.poll_interval)
- else:
- yield TriggerEvent({"status": "timeout"})
- return
- # mypy confuses typing here
- if await self.count_tasks() == len(self.execution_dates): # type: ignore[call-arg]
- yield TriggerEvent({"status": "success"})
- return
- self.log.info("Task is still running, sleeping for %s seconds...", self.poll_interval)
- await asyncio.sleep(self.poll_interval)
- except Exception:
- yield TriggerEvent({"status": "failed"})
-
- @sync_to_async
- @provide_session
- def count_running_dags(self, session: Session):
- """Count how many dag instances in running state in the database."""
- dags = (
- session.query(func.count("*"))
- .filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.execution_date.in_(self.execution_dates),
- TaskInstance.state.in_(["running", "success"]),
- )
- .scalar()
- )
- return dags
-
- @sync_to_async
- @provide_session
- def count_tasks(self, *, session: Session = NEW_SESSION) -> int | None:
- """Count how many task instances in the database match our criteria."""
- count = (
- session.query(func.count("*")) # .count() is inefficient
- .filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.state.in_(self.states),
- TaskInstance.execution_date.in_(self.execution_dates),
- )
- .scalar()
- )
- return typing.cast(int, count)
-
-
class DagStateTrigger(BaseTrigger):
"""
Waits asynchronously for a DAG to complete for a specific logical date.
diff --git a/newsfragments/41737.significant.rst b/newsfragments/41737.significant.rst
new file mode 100644
index 0000000000..55704581be
--- /dev/null
+++ b/newsfragments/41737.significant.rst
@@ -0,0 +1 @@
+Removed deprecated ``TaskStateTrigger`` from ``airflow.triggers.external_task`` module.
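Since the removed serialize() method recorded the classpath "airflow.triggers.external_task.TaskStateTrigger", any trigger serialized under that classpath can no longer be resolved after this change. A minimal sketch of such a resolution check follows; the helper function is hypothetical, not part of Airflow:

    from importlib import import_module

    def classpath_resolves(classpath: str) -> bool:
        """Return True if a dotted classpath can still be imported."""
        module_path, _, class_name = classpath.rpartition(".")
        try:
            return hasattr(import_module(module_path), class_name)
        except ImportError:
            return False

    # Classpath taken from the serialize() method removed in this commit;
    # it no longer resolves, while DagStateTrigger still does.
    assert not classpath_resolves("airflow.triggers.external_task.TaskStateTrigger")
    assert classpath_resolves("airflow.triggers.external_task.DagStateTrigger")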
diff --git a/tests/triggers/test_external_task.py b/tests/triggers/test_external_task.py
index 7bb41c3450..ced867c4bd 100644
--- a/tests/triggers/test_external_task.py
+++ b/tests/triggers/test_external_task.py
@@ -17,23 +17,17 @@
from __future__ import annotations
import asyncio
-import datetime
import time
from unittest import mock
import pytest
-from sqlalchemy.exc import SQLAlchemyError
-from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
-from airflow.models.taskinstance import TaskInstance
-from airflow.operators.empty import EmptyOperator
from airflow.triggers.base import TriggerEvent
-from airflow.triggers.external_task import DagStateTrigger, TaskStateTrigger, WorkflowTrigger
+from airflow.triggers.external_task import DagStateTrigger, WorkflowTrigger
from airflow.utils import timezone
-from airflow.utils.state import DagRunState, TaskInstanceState
-from airflow.utils.timezone import utcnow
+from airflow.utils.state import DagRunState
class TestWorkflowTrigger:
@@ -222,197 +216,6 @@ class TestWorkflowTrigger:
}
-class TestTaskStateTrigger:
- DAG_ID = "external_task"
- TASK_ID = "external_task_op"
- RUN_ID = "external_task_run_id"
- STATES = ["success", "fail"]
-
- @pytest.mark.skip_if_database_isolation_mode # Test is broken in db isolation mode
- @pytest.mark.db_test
- @pytest.mark.asyncio
- async def test_task_state_trigger_success(self, session):
- """
- Asserts that the TaskStateTrigger only goes off on or after a TaskInstance
- reaches an allowed state (i.e. SUCCESS).
- """
- trigger_start_time = utcnow()
- dag = DAG(self.DAG_ID, schedule=None, start_date=timezone.datetime(2022, 1, 1))
- dag_run = DagRun(
- dag_id=dag.dag_id,
- run_type="manual",
- execution_date=timezone.datetime(2022, 1, 1),
- run_id=self.RUN_ID,
- )
- session.add(dag_run)
- session.commit()
-
- external_task = EmptyOperator(task_id=self.TASK_ID, dag=dag)
- instance = TaskInstance(external_task, run_id=self.RUN_ID)
- session.add(instance)
- session.commit()
-
- with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
- trigger = TaskStateTrigger(
- dag_id=dag.dag_id,
- task_id=instance.task_id,
- states=self.STATES,
- execution_dates=[timezone.datetime(2022, 1, 1)],
- poll_interval=0.2,
- trigger_start_time=trigger_start_time,
- )
-
- task = asyncio.create_task(trigger.run().__anext__())
- await asyncio.sleep(0.5)
-
- # It should not have produced a result
- assert task.done() is False
-
- # Progress the task to a "success" state so that run() yields a TriggerEvent
- instance.state = TaskInstanceState.SUCCESS
- session.commit()
- await asyncio.sleep(0.5)
- assert task.done() is True
-
- # Prevents error when task is destroyed while in "pending" state
- asyncio.get_event_loop().stop()
-
- @mock.patch("airflow.triggers.external_task.utcnow")
- @pytest.mark.asyncio
- async def test_task_state_trigger_timeout(self, mock_utcnow):
- trigger_start_time = utcnow()
- mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=61)
-
- with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
- trigger = TaskStateTrigger(
- dag_id="dag1",
- task_id="task1",
- states=self.STATES,
- execution_dates=[timezone.datetime(2022, 1, 1)],
- poll_interval=0.2,
- trigger_start_time=trigger_start_time,
- )
-
- trigger.count_running_dags = mock.AsyncMock()
- trigger.count_running_dags.return_value = 0
-
- gen = trigger.run()
- task = asyncio.create_task(gen.__anext__())
- await task
-
- result = task.result()
- assert isinstance(result, TriggerEvent)
- assert result.payload == {"status": "timeout"}
- assert task.done() is True
-
- # test that it returns after yielding
- with pytest.raises(StopAsyncIteration):
- await gen.__anext__()
-
- @mock.patch("airflow.triggers.external_task.utcnow")
- @mock.patch("airflow.triggers.external_task.asyncio.sleep")
- @pytest.mark.asyncio
- async def test_task_state_trigger_timeout_sleep_success(self, mock_sleep, mock_utcnow):
- trigger_start_time = utcnow()
- mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=20)
-
- with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
- trigger = TaskStateTrigger(
- dag_id="dag1",
- task_id="task1",
- states=self.STATES,
- execution_dates=[timezone.datetime(2022, 1, 1)],
- poll_interval=0.2,
- trigger_start_time=trigger_start_time,
- )
-
- trigger.count_running_dags = mock.AsyncMock()
- trigger.count_running_dags.return_value = 0
-
- trigger.count_tasks = mock.AsyncMock()
- trigger.count_tasks.return_value = 1
-
- gen = trigger.run()
- task = asyncio.create_task(gen.__anext__())
- await task
-
- mock_sleep.assert_awaited()
- assert mock_sleep.await_count == 1
-
- result = task.result()
- assert isinstance(result, TriggerEvent)
- assert result.payload == {"status": "success"}
- assert task.done() is True
-
- # test that it returns after yielding
- with pytest.raises(StopAsyncIteration):
- await gen.__anext__()
-
- @mock.patch("airflow.triggers.external_task.utcnow")
- @mock.patch("airflow.triggers.external_task.asyncio.sleep")
- @pytest.mark.asyncio
- async def test_task_state_trigger_failed_exception(self, mock_sleep, mock_utcnow):
- """
- Asserts that the TaskStateTrigger only goes off on or after a TaskInstance
- reaches an allowed state (i.e. SUCCESS).
- """
- trigger_start_time = utcnow()
- mock_utcnow.return_value = +datetime.timedelta(seconds=61)
-
- mock_utcnow.side_effect = [
- trigger_start_time,
- trigger_start_time + datetime.timedelta(seconds=20),
- ]
-
- with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
- trigger = TaskStateTrigger(
- dag_id="dag1",
- task_id="task1",
- states=self.STATES,
- execution_dates=[timezone.datetime(2022, 1, 1)],
- poll_interval=0.2,
- trigger_start_time=trigger_start_time,
- )
-
- trigger.count_running_dags = mock.AsyncMock()
- trigger.count_running_dags.side_effect = [SQLAlchemyError]
-
- gen = trigger.run()
- task = asyncio.create_task(gen.__anext__())
- await task
-
- result = task.result()
- assert isinstance(result, TriggerEvent)
- assert result.payload == {"status": "failed"}
- assert task.done() is True
-
- def test_serialization(self):
- """
- Asserts that the TaskStateTrigger correctly serializes its arguments
- and classpath.
- """
- trigger_start_time = utcnow()
- with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
- trigger = TaskStateTrigger(
- dag_id=self.DAG_ID,
- task_id=self.TASK_ID,
- states=self.STATES,
- execution_dates=[timezone.datetime(2022, 1, 1)],
- poll_interval=5,
- trigger_start_time=trigger_start_time,
- )
- classpath, kwargs = trigger.serialize()
- assert classpath == "airflow.triggers.external_task.TaskStateTrigger"
- assert kwargs == {
- "dag_id": self.DAG_ID,
- "task_id": self.TASK_ID,
- "states": self.STATES,
- "execution_dates": [timezone.datetime(2022, 1, 1)],
- "poll_interval": 5,
- "trigger_start_time": trigger_start_time,
- }
-
-
class TestDagStateTrigger:
DAG_ID = "test_dag_state_trigger"
RUN_ID = "external_task_run_id"