This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e6eec0cfad Use async db calls in WorkflowTrigger (#38689)
e6eec0cfad is described below

commit e6eec0cfad424e402fe2a03b42818e706f0685ba
Author: Steven Schaerer <53116297+stevenschae...@users.noreply.github.com>
AuthorDate: Thu Apr 4 21:55:28 2024 +0200

    Use async db calls in WorkflowTrigger (#38689)
    
    * Use async db calls in WorkflowTrigger
    
    * address PR comments
    
    * deprecate TaskStateTrigger with proper category
---
 airflow/triggers/external_task.py        |  46 +++++-----
 contributing-docs/testing/unit_tests.rst |   2 +-
 tests/triggers/test_external_task.py     | 153 ++++++++++++++++++++-----------
 3 files changed, 123 insertions(+), 78 deletions(-)

diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py
index 5c7361a15b..a5de817c35 100644
--- a/airflow/triggers/external_task.py
+++ b/airflow/triggers/external_task.py
@@ -21,8 +21,10 @@ import typing
 from typing import Any
 
 from asgiref.sync import sync_to_async
+from deprecated import deprecated
 from sqlalchemy import func
 
+from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.models import DagRun, TaskInstance
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 from airflow.utils.sensor_helper import _get_count
@@ -98,13 +100,7 @@ class WorkflowTrigger(BaseTrigger):
         """Check periodically tasks, task group or dag status."""
         while True:
             if self.failed_states:
-                failed_count = _get_count(
-                    self.execution_dates,
-                    self.external_task_ids,
-                    self.external_task_group_id,
-                    self.external_dag_id,
-                    self.failed_states,
-                )
+                failed_count = await self._get_count(self.failed_states)
                 if failed_count > 0:
                     yield TriggerEvent({"status": "failed"})
                     return
@@ -112,30 +108,38 @@ class WorkflowTrigger(BaseTrigger):
                     yield TriggerEvent({"status": "success"})
                     return
             if self.skipped_states:
-                skipped_count = _get_count(
-                    self.execution_dates,
-                    self.external_task_ids,
-                    self.external_task_group_id,
-                    self.external_dag_id,
-                    self.skipped_states,
-                )
+                skipped_count = await self._get_count(self.skipped_states)
                 if skipped_count > 0:
                     yield TriggerEvent({"status": "skipped"})
                     return
-            allowed_count = _get_count(
-                self.execution_dates,
-                self.external_task_ids,
-                self.external_task_group_id,
-                self.external_dag_id,
-                self.allowed_states,
-            )
+            allowed_count = await self._get_count(self.allowed_states)
             if allowed_count == len(self.execution_dates):
                 yield TriggerEvent({"status": "success"})
                 return
             self.log.info("Sleeping for %s seconds", self.poke_interval)
             await asyncio.sleep(self.poke_interval)
 
+    @sync_to_async
+    def _get_count(self, states: typing.Iterable[str] | None) -> int:
+        """
+        Get the count of records against dttm filter and states. Async wrapper for _get_count.
+
+        :param states: task or dag states
+        :return The count of records.
+        """
+        return _get_count(
+            dttm_filter=self.execution_dates,
+            external_task_ids=self.external_task_ids,
+            external_task_group_id=self.external_task_group_id,
+            external_dag_id=self.external_dag_id,
+            states=states,
+        )
+
 
+@deprecated(
+    reason="TaskStateTrigger has been deprecated and will be removed in future.",
+    category=RemovedInAirflow3Warning,
+)
 class TaskStateTrigger(BaseTrigger):
     """
     Waits asynchronously for a task in a different DAG to complete for a specific logical date.
diff --git a/contributing-docs/testing/unit_tests.rst b/contributing-docs/testing/unit_tests.rst
index 56e9cb1799..57d2949520 100644
--- a/contributing-docs/testing/unit_tests.rst
+++ b/contributing-docs/testing/unit_tests.rst
@@ -66,7 +66,7 @@ For avoid this make sure:
 .. code-block:: python
 
     def test_deprecated_argument():
-        with pytest.warn(AirflowProviderDeprecationWarning, match="expected warning pattern"):
+        with pytest.warns(AirflowProviderDeprecationWarning, match="expected warning pattern"):
             SomeDeprecatedClass(foo="bar", spam="egg")
 
 
diff --git a/tests/triggers/test_external_task.py b/tests/triggers/test_external_task.py
index fe773049fa..8ce6d89a3a 100644
--- a/tests/triggers/test_external_task.py
+++ b/tests/triggers/test_external_task.py
@@ -18,11 +18,13 @@ from __future__ import annotations
 
 import asyncio
 import datetime
+import time
 from unittest import mock
 
 import pytest
 from sqlalchemy.exc import SQLAlchemyError
 
+from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
@@ -41,11 +43,10 @@ class TestWorkflowTrigger:
     STATES = ["success", "fail"]
 
     @mock.patch("airflow.triggers.external_task._get_count")
-    @mock.patch("asyncio.sleep")
     @pytest.mark.asyncio
-    async def test_task_workflow_trigger_success(self, mock_sleep, mock_get_count):
+    async def test_task_workflow_trigger_success(self, mock_get_count):
         """check the db count get called correctly."""
-        mock_get_count.return_value = 1
+        mock_get_count.side_effect = mocked_get_count
         trigger = WorkflowTrigger(
             external_dag_id=self.DAG_ID,
             execution_dates=[timezone.datetime(2022, 1, 1)],
@@ -54,19 +55,29 @@ class TestWorkflowTrigger:
             poke_interval=0.2,
         )
 
-        generator = trigger.run()
-        await generator.asend(None)
+        gen = trigger.run()
+        trigger_task = asyncio.create_task(gen.__anext__())
+        fake_task = asyncio.create_task(fake_async_fun())
+        await trigger_task
+        assert fake_task.done()  # confirm that get_count is done in an async fashion
+        assert trigger_task.done()
+        result = trigger_task.result()
+        assert result.payload == {"status": "success"}
         mock_get_count.assert_called_once_with(
-            [timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
+            dttm_filter=[timezone.datetime(2022, 1, 1)],
+            external_task_ids=["external_task_op"],
+            external_task_group_id=None,
+            external_dag_id="external_task",
+            states=["success", "fail"],
         )
         # test that it returns after yielding
         with pytest.raises(StopAsyncIteration):
-            await generator.__anext__()
+            await gen.__anext__()
 
     @mock.patch("airflow.triggers.external_task._get_count")
     @pytest.mark.asyncio
     async def test_task_workflow_trigger_failed(self, mock_get_count):
-        mock_get_count.return_value = 1
+        mock_get_count.side_effect = mocked_get_count
         trigger = WorkflowTrigger(
             external_dag_id=self.DAG_ID,
             execution_dates=[timezone.datetime(2022, 1, 1)],
@@ -77,13 +88,19 @@ class TestWorkflowTrigger:
 
         gen = trigger.run()
         trigger_task = asyncio.create_task(gen.__anext__())
+        fake_task = asyncio.create_task(fake_async_fun())
         await trigger_task
-        assert trigger_task.done() is True
+        assert fake_task.done()  # confirm that get_count is done in an async fashion
+        assert trigger_task.done()
         result = trigger_task.result()
         assert isinstance(result, TriggerEvent)
         assert result.payload == {"status": "failed"}
         mock_get_count.assert_called_once_with(
-            [timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
+            dttm_filter=[timezone.datetime(2022, 1, 1)],
+            external_task_ids=["external_task_op"],
+            external_task_group_id=None,
+            external_dag_id="external_task",
+            states=["success", "fail"],
         )
         # test that it returns after yielding
         with pytest.raises(StopAsyncIteration):
@@ -104,12 +121,16 @@ class TestWorkflowTrigger:
         gen = trigger.run()
         trigger_task = asyncio.create_task(gen.__anext__())
         await trigger_task
-        assert trigger_task.done() is True
+        assert trigger_task.done()
         result = trigger_task.result()
         assert isinstance(result, TriggerEvent)
         assert result.payload == {"status": "success"}
         mock_get_count.assert_called_once_with(
-            [timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
+            dttm_filter=[timezone.datetime(2022, 1, 1)],
+            external_task_ids=["external_task_op"],
+            external_task_group_id=None,
+            external_dag_id="external_task",
+            states=["success", "fail"],
         )
         # test that it returns after yielding
         with pytest.raises(StopAsyncIteration):
@@ -118,7 +139,7 @@ class TestWorkflowTrigger:
     @mock.patch("airflow.triggers.external_task._get_count")
     @pytest.mark.asyncio
     async def test_task_workflow_trigger_skipped(self, mock_get_count):
-        mock_get_count.return_value = 1
+        mock_get_count.side_effect = mocked_get_count
         trigger = WorkflowTrigger(
             external_dag_id=self.DAG_ID,
             execution_dates=[timezone.datetime(2022, 1, 1)],
@@ -129,13 +150,19 @@ class TestWorkflowTrigger:
 
         gen = trigger.run()
         trigger_task = asyncio.create_task(gen.__anext__())
+        fake_task = asyncio.create_task(fake_async_fun())
         await trigger_task
-        assert trigger_task.done() is True
+        assert fake_task.done()  # confirm that get_count is done in an async fashion
+        assert trigger_task.done()
         result = trigger_task.result()
         assert isinstance(result, TriggerEvent)
         assert result.payload == {"status": "skipped"}
         mock_get_count.assert_called_once_with(
-            [timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
+            dttm_filter=[timezone.datetime(2022, 1, 1)],
+            external_task_ids=["external_task_op"],
+            external_task_group_id=None,
+            external_dag_id="external_task",
+            states=["success", "fail"],
         )
 
     @mock.patch("airflow.triggers.external_task._get_count")
@@ -153,7 +180,7 @@ class TestWorkflowTrigger:
         gen = trigger.run()
         trigger_task = asyncio.create_task(gen.__anext__())
         await trigger_task
-        assert trigger_task.done() is True
+        assert trigger_task.done()
         result = trigger_task.result()
         assert isinstance(result, TriggerEvent)
         assert result.payload == {"status": "success"}
@@ -222,14 +249,15 @@ class TestTaskStateTrigger:
         session.add(instance)
         session.commit()
 
-        trigger = TaskStateTrigger(
-            dag_id=dag.dag_id,
-            task_id=instance.task_id,
-            states=self.STATES,
-            execution_dates=[timezone.datetime(2022, 1, 1)],
-            poll_interval=0.2,
-            trigger_start_time=trigger_start_time,
-        )
+        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
+            trigger = TaskStateTrigger(
+                dag_id=dag.dag_id,
+                task_id=instance.task_id,
+                states=self.STATES,
+                execution_dates=[timezone.datetime(2022, 1, 1)],
+                poll_interval=0.2,
+                trigger_start_time=trigger_start_time,
+            )
 
         task = asyncio.create_task(trigger.run().__anext__())
         await asyncio.sleep(0.5)
@@ -252,14 +280,15 @@ class TestTaskStateTrigger:
         trigger_start_time = utcnow()
         mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=61)
 
-        trigger = TaskStateTrigger(
-            dag_id="dag1",
-            task_id="task1",
-            states=self.STATES,
-            execution_dates=[timezone.datetime(2022, 1, 1)],
-            poll_interval=0.2,
-            trigger_start_time=trigger_start_time,
-        )
+        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
+            trigger = TaskStateTrigger(
+                dag_id="dag1",
+                task_id="task1",
+                states=self.STATES,
+                execution_dates=[timezone.datetime(2022, 1, 1)],
+                poll_interval=0.2,
+                trigger_start_time=trigger_start_time,
+            )
 
         trigger.count_running_dags = mock.AsyncMock()
         trigger.count_running_dags.return_value = 0
@@ -284,14 +313,15 @@ class TestTaskStateTrigger:
         trigger_start_time = utcnow()
         mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=20)
 
-        trigger = TaskStateTrigger(
-            dag_id="dag1",
-            task_id="task1",
-            states=self.STATES,
-            execution_dates=[timezone.datetime(2022, 1, 1)],
-            poll_interval=0.2,
-            trigger_start_time=trigger_start_time,
-        )
+        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
+            trigger = TaskStateTrigger(
+                dag_id="dag1",
+                task_id="task1",
+                states=self.STATES,
+                execution_dates=[timezone.datetime(2022, 1, 1)],
+                poll_interval=0.2,
+                trigger_start_time=trigger_start_time,
+            )
 
         trigger.count_running_dags = mock.AsyncMock()
         trigger.count_running_dags.return_value = 0
@@ -331,14 +361,15 @@ class TestTaskStateTrigger:
             trigger_start_time + datetime.timedelta(seconds=20),
         ]
 
-        trigger = TaskStateTrigger(
-            dag_id="dag1",
-            task_id="task1",
-            states=self.STATES,
-            execution_dates=[timezone.datetime(2022, 1, 1)],
-            poll_interval=0.2,
-            trigger_start_time=trigger_start_time,
-        )
+        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
+            trigger = TaskStateTrigger(
+                dag_id="dag1",
+                task_id="task1",
+                states=self.STATES,
+                execution_dates=[timezone.datetime(2022, 1, 1)],
+                poll_interval=0.2,
+                trigger_start_time=trigger_start_time,
+            )
 
         trigger.count_running_dags = mock.AsyncMock()
         trigger.count_running_dags.side_effect = [SQLAlchemyError]
@@ -358,14 +389,15 @@ class TestTaskStateTrigger:
         and classpath.
         """
         trigger_start_time = utcnow()
-        trigger = TaskStateTrigger(
-            dag_id=self.DAG_ID,
-            task_id=self.TASK_ID,
-            states=self.STATES,
-            execution_dates=[timezone.datetime(2022, 1, 1)],
-            poll_interval=5,
-            trigger_start_time=trigger_start_time,
-        )
+        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
+            trigger = TaskStateTrigger(
+                dag_id=self.DAG_ID,
+                task_id=self.TASK_ID,
+                states=self.STATES,
+                execution_dates=[timezone.datetime(2022, 1, 1)],
+                poll_interval=5,
+                trigger_start_time=trigger_start_time,
+            )
         classpath, kwargs = trigger.serialize()
         assert classpath == "airflow.triggers.external_task.TaskStateTrigger"
         assert kwargs == {
@@ -438,3 +470,12 @@ class TestDagStateTrigger:
             "execution_dates": [timezone.datetime(2022, 1, 1)],
             "poll_interval": 5,
         }
+
+
+def mocked_get_count(*args, **kwargs):
+    time.sleep(0.0001)
+    return 1
+
+
+async def fake_async_fun():
+    await asyncio.sleep(0.00005)

Reply via email to