This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e6eec0cfad Use async db calls in WorkflowTrigger (#38689)
e6eec0cfad is described below
commit e6eec0cfad424e402fe2a03b42818e706f0685ba
Author: Steven Schaerer <[email protected]>
AuthorDate: Thu Apr 4 21:55:28 2024 +0200
Use async db calls in WorkflowTrigger (#38689)
* Use async db calls in WorkflowTrigger
* address PR comments
* deprecate TaskStateTrigger with proper category
---
airflow/triggers/external_task.py | 46 +++++-----
contributing-docs/testing/unit_tests.rst | 2 +-
tests/triggers/test_external_task.py | 153 ++++++++++++++++++++-----------
3 files changed, 123 insertions(+), 78 deletions(-)
diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py
index 5c7361a15b..a5de817c35 100644
--- a/airflow/triggers/external_task.py
+++ b/airflow/triggers/external_task.py
@@ -21,8 +21,10 @@ import typing
from typing import Any
from asgiref.sync import sync_to_async
+from deprecated import deprecated
from sqlalchemy import func
+from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models import DagRun, TaskInstance
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.sensor_helper import _get_count
@@ -98,13 +100,7 @@ class WorkflowTrigger(BaseTrigger):
"""Check periodically tasks, task group or dag status."""
while True:
if self.failed_states:
- failed_count = _get_count(
- self.execution_dates,
- self.external_task_ids,
- self.external_task_group_id,
- self.external_dag_id,
- self.failed_states,
- )
+ failed_count = await self._get_count(self.failed_states)
if failed_count > 0:
yield TriggerEvent({"status": "failed"})
return
@@ -112,30 +108,38 @@ class WorkflowTrigger(BaseTrigger):
yield TriggerEvent({"status": "success"})
return
if self.skipped_states:
- skipped_count = _get_count(
- self.execution_dates,
- self.external_task_ids,
- self.external_task_group_id,
- self.external_dag_id,
- self.skipped_states,
- )
+ skipped_count = await self._get_count(self.skipped_states)
if skipped_count > 0:
yield TriggerEvent({"status": "skipped"})
return
- allowed_count = _get_count(
- self.execution_dates,
- self.external_task_ids,
- self.external_task_group_id,
- self.external_dag_id,
- self.allowed_states,
- )
+ allowed_count = await self._get_count(self.allowed_states)
if allowed_count == len(self.execution_dates):
yield TriggerEvent({"status": "success"})
return
self.log.info("Sleeping for %s seconds", self.poke_interval)
await asyncio.sleep(self.poke_interval)
+ @sync_to_async
+ def _get_count(self, states: typing.Iterable[str] | None) -> int:
+ """
+ Get the count of records against dttm filter and states.
+
+ Async wrapper for the module-level ``_get_count`` helper: ``sync_to_async``
+ runs the synchronous count off the event loop so ``run`` can await it.
+
+ :param states: task or dag states to match
+ :return: The count of records.
+ """
+ # Inside this body the bare name resolves to the module-level helper imported
+ # from airflow.utils.sensor_helper, not to this method.
+ return _get_count(
+ dttm_filter=self.execution_dates,
+ external_task_ids=self.external_task_ids,
+ external_task_group_id=self.external_task_group_id,
+ external_dag_id=self.external_dag_id,
+ states=states,
+ )
+
+@deprecated(
+ reason="TaskStateTrigger has been deprecated and will be removed in
future.",
+ category=RemovedInAirflow3Warning,
+)
class TaskStateTrigger(BaseTrigger):
"""
Waits asynchronously for a task in a different DAG to complete for a
specific logical date.
diff --git a/contributing-docs/testing/unit_tests.rst b/contributing-docs/testing/unit_tests.rst
index 56e9cb1799..57d2949520 100644
--- a/contributing-docs/testing/unit_tests.rst
+++ b/contributing-docs/testing/unit_tests.rst
@@ -66,7 +66,7 @@ For avoid this make sure:
.. code-block:: python
def test_deprecated_argument():
- with pytest.warn(AirflowProviderDeprecationWarning, match="expected
warning pattern"):
+ with pytest.warns(AirflowProviderDeprecationWarning, match="expected
warning pattern"):
SomeDeprecatedClass(foo="bar", spam="egg")
diff --git a/tests/triggers/test_external_task.py b/tests/triggers/test_external_task.py
index fe773049fa..8ce6d89a3a 100644
--- a/tests/triggers/test_external_task.py
+++ b/tests/triggers/test_external_task.py
@@ -18,11 +18,13 @@ from __future__ import annotations
import asyncio
import datetime
+import time
from unittest import mock
import pytest
from sqlalchemy.exc import SQLAlchemyError
+from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
@@ -41,11 +43,10 @@ class TestWorkflowTrigger:
STATES = ["success", "fail"]
@mock.patch("airflow.triggers.external_task._get_count")
- @mock.patch("asyncio.sleep")
@pytest.mark.asyncio
- async def test_task_workflow_trigger_success(self, mock_sleep,
mock_get_count):
+ async def test_task_workflow_trigger_success(self, mock_get_count):
"""check the db count get called correctly."""
- mock_get_count.return_value = 1
+ mock_get_count.side_effect = mocked_get_count
trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
execution_dates=[timezone.datetime(2022, 1, 1)],
@@ -54,19 +55,29 @@ class TestWorkflowTrigger:
poke_interval=0.2,
)
- generator = trigger.run()
- await generator.asend(None)
+ gen = trigger.run()
+ trigger_task = asyncio.create_task(gen.__anext__())
+ fake_task = asyncio.create_task(fake_async_fun())
+ await trigger_task
+ assert fake_task.done() # confirm that get_count is done in an async
fashion
+ assert trigger_task.done()
+ result = trigger_task.result()
+ assert result.payload == {"status": "success"}
mock_get_count.assert_called_once_with(
- [timezone.datetime(2022, 1, 1)], ["external_task_op"], None,
"external_task", ["success", "fail"]
+ dttm_filter=[timezone.datetime(2022, 1, 1)],
+ external_task_ids=["external_task_op"],
+ external_task_group_id=None,
+ external_dag_id="external_task",
+ states=["success", "fail"],
)
# test that it returns after yielding
with pytest.raises(StopAsyncIteration):
- await generator.__anext__()
+ await gen.__anext__()
@mock.patch("airflow.triggers.external_task._get_count")
@pytest.mark.asyncio
async def test_task_workflow_trigger_failed(self, mock_get_count):
- mock_get_count.return_value = 1
+ mock_get_count.side_effect = mocked_get_count
trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
execution_dates=[timezone.datetime(2022, 1, 1)],
@@ -77,13 +88,19 @@ class TestWorkflowTrigger:
gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
+ fake_task = asyncio.create_task(fake_async_fun())
await trigger_task
- assert trigger_task.done() is True
+ assert fake_task.done() # confirm that get_count is done in an async
fashion
+ assert trigger_task.done()
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == {"status": "failed"}
mock_get_count.assert_called_once_with(
- [timezone.datetime(2022, 1, 1)], ["external_task_op"], None,
"external_task", ["success", "fail"]
+ dttm_filter=[timezone.datetime(2022, 1, 1)],
+ external_task_ids=["external_task_op"],
+ external_task_group_id=None,
+ external_dag_id="external_task",
+ states=["success", "fail"],
)
# test that it returns after yielding
with pytest.raises(StopAsyncIteration):
@@ -104,12 +121,16 @@ class TestWorkflowTrigger:
gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
await trigger_task
- assert trigger_task.done() is True
+ assert trigger_task.done()
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == {"status": "success"}
mock_get_count.assert_called_once_with(
- [timezone.datetime(2022, 1, 1)], ["external_task_op"], None,
"external_task", ["success", "fail"]
+ dttm_filter=[timezone.datetime(2022, 1, 1)],
+ external_task_ids=["external_task_op"],
+ external_task_group_id=None,
+ external_dag_id="external_task",
+ states=["success", "fail"],
)
# test that it returns after yielding
with pytest.raises(StopAsyncIteration):
@@ -118,7 +139,7 @@ class TestWorkflowTrigger:
@mock.patch("airflow.triggers.external_task._get_count")
@pytest.mark.asyncio
async def test_task_workflow_trigger_skipped(self, mock_get_count):
- mock_get_count.return_value = 1
+ mock_get_count.side_effect = mocked_get_count
trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
execution_dates=[timezone.datetime(2022, 1, 1)],
@@ -129,13 +150,19 @@ class TestWorkflowTrigger:
gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
+ fake_task = asyncio.create_task(fake_async_fun())
await trigger_task
- assert trigger_task.done() is True
+ assert fake_task.done() # confirm that get_count is done in an async
fashion
+ assert trigger_task.done()
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == {"status": "skipped"}
mock_get_count.assert_called_once_with(
- [timezone.datetime(2022, 1, 1)], ["external_task_op"], None,
"external_task", ["success", "fail"]
+ dttm_filter=[timezone.datetime(2022, 1, 1)],
+ external_task_ids=["external_task_op"],
+ external_task_group_id=None,
+ external_dag_id="external_task",
+ states=["success", "fail"],
)
@mock.patch("airflow.triggers.external_task._get_count")
@@ -153,7 +180,7 @@ class TestWorkflowTrigger:
gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
await trigger_task
- assert trigger_task.done() is True
+ assert trigger_task.done()
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == {"status": "success"}
@@ -222,14 +249,15 @@ class TestTaskStateTrigger:
session.add(instance)
session.commit()
- trigger = TaskStateTrigger(
- dag_id=dag.dag_id,
- task_id=instance.task_id,
- states=self.STATES,
- execution_dates=[timezone.datetime(2022, 1, 1)],
- poll_interval=0.2,
- trigger_start_time=trigger_start_time,
- )
+ with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger
has been deprecated"):
+ trigger = TaskStateTrigger(
+ dag_id=dag.dag_id,
+ task_id=instance.task_id,
+ states=self.STATES,
+ execution_dates=[timezone.datetime(2022, 1, 1)],
+ poll_interval=0.2,
+ trigger_start_time=trigger_start_time,
+ )
task = asyncio.create_task(trigger.run().__anext__())
await asyncio.sleep(0.5)
@@ -252,14 +280,15 @@ class TestTaskStateTrigger:
trigger_start_time = utcnow()
mock_utcnow.return_value = trigger_start_time +
datetime.timedelta(seconds=61)
- trigger = TaskStateTrigger(
- dag_id="dag1",
- task_id="task1",
- states=self.STATES,
- execution_dates=[timezone.datetime(2022, 1, 1)],
- poll_interval=0.2,
- trigger_start_time=trigger_start_time,
- )
+ with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger
has been deprecated"):
+ trigger = TaskStateTrigger(
+ dag_id="dag1",
+ task_id="task1",
+ states=self.STATES,
+ execution_dates=[timezone.datetime(2022, 1, 1)],
+ poll_interval=0.2,
+ trigger_start_time=trigger_start_time,
+ )
trigger.count_running_dags = mock.AsyncMock()
trigger.count_running_dags.return_value = 0
@@ -284,14 +313,15 @@ class TestTaskStateTrigger:
trigger_start_time = utcnow()
mock_utcnow.return_value = trigger_start_time +
datetime.timedelta(seconds=20)
- trigger = TaskStateTrigger(
- dag_id="dag1",
- task_id="task1",
- states=self.STATES,
- execution_dates=[timezone.datetime(2022, 1, 1)],
- poll_interval=0.2,
- trigger_start_time=trigger_start_time,
- )
+ with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger
has been deprecated"):
+ trigger = TaskStateTrigger(
+ dag_id="dag1",
+ task_id="task1",
+ states=self.STATES,
+ execution_dates=[timezone.datetime(2022, 1, 1)],
+ poll_interval=0.2,
+ trigger_start_time=trigger_start_time,
+ )
trigger.count_running_dags = mock.AsyncMock()
trigger.count_running_dags.return_value = 0
@@ -331,14 +361,15 @@ class TestTaskStateTrigger:
trigger_start_time + datetime.timedelta(seconds=20),
]
- trigger = TaskStateTrigger(
- dag_id="dag1",
- task_id="task1",
- states=self.STATES,
- execution_dates=[timezone.datetime(2022, 1, 1)],
- poll_interval=0.2,
- trigger_start_time=trigger_start_time,
- )
+ with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger
has been deprecated"):
+ trigger = TaskStateTrigger(
+ dag_id="dag1",
+ task_id="task1",
+ states=self.STATES,
+ execution_dates=[timezone.datetime(2022, 1, 1)],
+ poll_interval=0.2,
+ trigger_start_time=trigger_start_time,
+ )
trigger.count_running_dags = mock.AsyncMock()
trigger.count_running_dags.side_effect = [SQLAlchemyError]
@@ -358,14 +389,15 @@ class TestTaskStateTrigger:
and classpath.
"""
trigger_start_time = utcnow()
- trigger = TaskStateTrigger(
- dag_id=self.DAG_ID,
- task_id=self.TASK_ID,
- states=self.STATES,
- execution_dates=[timezone.datetime(2022, 1, 1)],
- poll_interval=5,
- trigger_start_time=trigger_start_time,
- )
+ with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger
has been deprecated"):
+ trigger = TaskStateTrigger(
+ dag_id=self.DAG_ID,
+ task_id=self.TASK_ID,
+ states=self.STATES,
+ execution_dates=[timezone.datetime(2022, 1, 1)],
+ poll_interval=5,
+ trigger_start_time=trigger_start_time,
+ )
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.triggers.external_task.TaskStateTrigger"
assert kwargs == {
@@ -438,3 +470,12 @@ class TestDagStateTrigger:
"execution_dates": [timezone.datetime(2022, 1, 1)],
"poll_interval": 5,
}
+
+
+def mocked_get_count(*args, **kwargs):
+ """Stand-in for ``_get_count``: block briefly, then report one matching record."""
+ # Deliberately a blocking (non-asyncio) sleep, longer than fake_async_fun's
+ # 0.00005s await: fake_async_fun can only finish first if this call runs off
+ # the event loop, which is what the tests' ``fake_task.done()`` checks rely on.
+ time.sleep(0.0001)
+ return 1
+
+
+async def fake_async_fun():
+ """Yield to the event loop briefly; scheduled next to a trigger task to detect blocking calls."""
+ # Shorter than the blocking sleep in mocked_get_count, so this task completes
+ # before the trigger task only when the count call does not block the loop.
+ await asyncio.sleep(0.00005)