This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new dda3dcdcfc Add deferrable mode to ExternalTaskSensor (#29260)
dda3dcdcfc is described below
commit dda3dcdcfcc4d687f0b66e3cd993e0f5faf0caee
Author: Beata Kossakowska <[email protected]>
AuthorDate: Thu Jul 20 19:33:28 2023 +0200
Add deferrable mode to ExternalTaskSensor (#29260)
---------
Co-authored-by: Beata Kossakowska <[email protected]>
Co-authored-by: VladaZakharova <[email protected]>
---
airflow/sensors/external_task.py | 43 ++++++++++++
airflow/triggers/external_task.py | 80 ++++++++++++++++-----
.../howto/operator/external_task_sensor.rst | 9 +++
tests/sensors/test_external_task_sensor.py | 82 +++++++++++++++++++++-
tests/system/providers/core/__init__.py | 16 +++++
.../core/example_external_task_child_deferrable.py | 40 +++++++++++
.../example_external_task_parent_deferrable.py | 70 ++++++++++++++++++
tests/triggers/test_external_task.py | 6 ++
8 files changed, 328 insertions(+), 18 deletions(-)
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index b5c1c1321f..5e42820ffe 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any, Callable, Collection,
Iterable
import attr
from sqlalchemy import func
+from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException,
RemovedInAirflow3Warning
from airflow.models.baseoperator import BaseOperatorLink
from airflow.models.dag import DagModel
@@ -33,11 +34,13 @@ from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.sensors.base import BaseSensorOperator
+from airflow.triggers.external_task import TaskStateTrigger
from airflow.utils.file import correct_maybe_zipped
from airflow.utils.helpers import build_airflow_url_with_query
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.timezone import utcnow
if TYPE_CHECKING:
from sqlalchemy.orm import Query, Session
@@ -126,6 +129,8 @@ class ExternalTaskSensor(BaseSensorOperator):
external_task_id is not None) or check if the DAG to wait for exists
(when
external_task_id is None), and immediately cease waiting if the
external task
or DAG does not exist (default value: False).
+ :param poll_interval: polling period in seconds to check for the status
+ :param deferrable: Run sensor in deferrable mode
"""
template_fields = ["external_dag_id", "external_task_id",
"external_task_ids", "external_task_group_id"]
@@ -145,9 +150,12 @@ class ExternalTaskSensor(BaseSensorOperator):
execution_delta: datetime.timedelta | None = None,
execution_date_fn: Callable | None = None,
check_existence: bool = False,
+ poll_interval: float = 2.0,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
):
super().__init__(**kwargs)
+
self.allowed_states = list(allowed_states) if allowed_states else
[TaskInstanceState.SUCCESS.value]
self.skipped_states = list(skipped_states) if skipped_states else []
self.failed_states = list(failed_states) if failed_states else []
@@ -211,6 +219,8 @@ class ExternalTaskSensor(BaseSensorOperator):
self.external_task_group_id = external_task_group_id
self.check_existence = check_existence
self._has_checked_existence = False
+ self.deferrable = deferrable
+ self.poll_interval = poll_interval
def _get_dttm_filter(self, context):
if self.execution_delta:
@@ -318,6 +328,39 @@ class ExternalTaskSensor(BaseSensorOperator):
count_allowed = self.get_count(dttm_filter, session,
self.allowed_states)
return count_allowed == len(dttm_filter)
+ def execute(self, context: Context) -> None:
+        """
+        Airflow runs this method on the worker and defers using the trigger
+        if deferrable is set to True.
+        """
+ if not self.deferrable:
+ super().execute(context)
+ else:
+ self.defer(
+ trigger=TaskStateTrigger(
+ dag_id=self.external_dag_id,
+ task_id=self.external_task_id,
+ execution_dates=self._get_dttm_filter(context),
+ states=self.allowed_states,
+ trigger_start_time=utcnow(),
+ poll_interval=self.poll_interval,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context, event=None):
+ """Callback for when the trigger fires - returns immediately."""
+ if event["status"] == "success":
+ self.log.info("External task %s has executed successfully.",
self.external_task_id)
+ return None
+ elif event["status"] == "timeout":
+ raise AirflowException("Dag was not started within 1 minute,
assuming fail.")
+ else:
+ raise AirflowException(
+ "Error occurred while trying to retrieve task status. Please,
check the "
+ "name of executed task and Dag."
+ )
+
def _check_for_existence(self, session) -> None:
dag_to_wait = DagModel.get_current(self.external_dag_id, session)
diff --git a/airflow/triggers/external_task.py
b/airflow/triggers/external_task.py
index e739c7a7cb..f179cba259 100644
--- a/airflow/triggers/external_task.py
+++ b/airflow/triggers/external_task.py
@@ -17,8 +17,8 @@
from __future__ import annotations
import asyncio
-import datetime
import typing
+from datetime import datetime
from asgiref.sync import sync_to_async
from sqlalchemy import func
@@ -27,7 +27,8 @@ from sqlalchemy.orm import Session
from airflow.models import DagRun, TaskInstance
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.state import DagRunState
+from airflow.utils.state import DagRunState, TaskInstanceState
+from airflow.utils.timezone import utcnow
class TaskStateTrigger(BaseTrigger):
@@ -36,20 +37,26 @@ class TaskStateTrigger(BaseTrigger):
:param dag_id: The dag_id that contains the task you want to wait for
:param task_id: The task_id that contains the task you want to
- wait for. If ``None`` (default value) the sensor waits for the DAG
+ wait for.
:param states: allowed states, default is ``['success']``
- :param execution_dates:
+ :param execution_dates: task execution time interval
:param poll_interval: The time interval in seconds to check the state.
The default value is 5 sec.
+    :param trigger_start_time: time in datetime format when the trigger was
started. It is used
+        to control the execution of the trigger and to prevent an infinite loop
in case the specified
+        dag does not exist in the database. The trigger will wait for a period
of time equal to the
+        _timeout_sec parameter from the moment it was started, and if execution
lasts longer than
+        expected, the trigger will terminate with 'timeout' status.
"""
def __init__(
self,
dag_id: str,
- task_id: str,
- states: list[str],
- execution_dates: list[datetime.datetime],
- poll_interval: float = 5.0,
+ execution_dates: list[datetime],
+ trigger_start_time: datetime,
+ states: list[str] | None = None,
+ task_id: str | None = None,
+ poll_interval: float = 2.0,
):
super().__init__()
self.dag_id = dag_id
@@ -57,6 +64,9 @@ class TaskStateTrigger(BaseTrigger):
self.states = states
self.execution_dates = execution_dates
self.poll_interval = poll_interval
+ self.trigger_start_time = trigger_start_time
+ self.states = states if states else [TaskInstanceState.SUCCESS.value]
+ self._timeout_sec = 60
def serialize(self) -> tuple[str, dict[str, typing.Any]]:
"""Serializes TaskStateTrigger arguments and classpath."""
@@ -68,17 +78,52 @@ class TaskStateTrigger(BaseTrigger):
"states": self.states,
"execution_dates": self.execution_dates,
"poll_interval": self.poll_interval,
+ "trigger_start_time": self.trigger_start_time,
},
)
async def run(self) -> typing.AsyncIterator[TriggerEvent]:
- """Checks periodically in the database to see if the task exists and
has hit one of the states."""
+ """
+ Checks periodically in the database to see if the dag exists and is in
the running state. If found,
+ wait until the task specified will reach one of the expected states.
If dag with specified name was
+ not in the running state after _timeout_sec seconds after starting
execution process of the trigger,
+ terminate with status 'timeout'.
+ """
while True:
- # mypy confuses typing here
- num_tasks = await self.count_tasks() # type: ignore[call-arg]
- if num_tasks == len(self.execution_dates):
- yield TriggerEvent(True)
- await asyncio.sleep(self.poll_interval)
+ try:
+ delta = utcnow() - self.trigger_start_time
+ if delta.total_seconds() < self._timeout_sec:
+ # mypy confuses typing here
+ if await self.count_running_dags() == 0: # type:
ignore[call-arg]
+ self.log.info("Waiting for DAG to start execution...")
+ await asyncio.sleep(self.poll_interval)
+ else:
+ yield TriggerEvent({"status": "timeout"})
+ return
+ # mypy confuses typing here
+ if await self.count_tasks() == len(self.execution_dates): #
type: ignore[call-arg]
+ yield TriggerEvent({"status": "success"})
+ return
+ self.log.info("Task is still running, sleeping for %s
seconds...", self.poll_interval)
+ await asyncio.sleep(self.poll_interval)
+ except Exception:
+ yield TriggerEvent({"status": "failed"})
+ return
+
+ @sync_to_async
+ @provide_session
+ def count_running_dags(self, session: Session):
+        """Count how many dag instances are in the running state in the database."""
+ dags = (
+ session.query(func.count("*"))
+ .filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.execution_date.in_(self.execution_dates),
+ TaskInstance.state.in_(["running", "success"]),
+ )
+ .scalar()
+ )
+ return dags
@sync_to_async
@provide_session
@@ -112,7 +157,7 @@ class DagStateTrigger(BaseTrigger):
self,
dag_id: str,
states: list[DagRunState],
- execution_dates: list[datetime.datetime],
+ execution_dates: list[datetime],
poll_interval: float = 5.0,
):
super().__init__()
@@ -134,7 +179,10 @@ class DagStateTrigger(BaseTrigger):
)
async def run(self) -> typing.AsyncIterator[TriggerEvent]:
- """Checks periodically in the database to see if the dag run exists
and has hit one of the states."""
+ """
+ Checks periodically in the database to see if the dag run exists, and
has
+ hit one of the states yet, or not.
+ """
while True:
# mypy confuses typing here
num_dags = await self.count_dags() # type: ignore[call-arg]
diff --git a/docs/apache-airflow/howto/operator/external_task_sensor.rst
b/docs/apache-airflow/howto/operator/external_task_sensor.rst
index 923f8ec3d1..f6f53f87e5 100644
--- a/docs/apache-airflow/howto/operator/external_task_sensor.rst
+++ b/docs/apache-airflow/howto/operator/external_task_sensor.rst
@@ -53,6 +53,15 @@ via ``allowed_states`` and ``failed_states`` parameters.
:start-after: [START howto_operator_external_task_sensor]
:end-before: [END howto_operator_external_task_sensor]
+For this action you can also use the sensor in deferrable mode:
+
+.. exampleinclude::
/../../tests/system/providers/core/example_external_task_parent_deferrable.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_external_task_async_sensor]
+ :end-before: [END howto_external_task_async_sensor]
+
+
ExternalTaskSensor with task_group dependency
---------------------------------------------
In Addition, we can also use the
:class:`~airflow.sensors.external_task.ExternalTaskSensor` to make tasks on a
DAG
diff --git a/tests/sensors/test_external_task_sensor.py
b/tests/sensors/test_external_task_sensor.py
index a5259084b1..e84b3f69f4 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -22,12 +22,13 @@ import os
import tempfile
import zipfile
from datetime import time, timedelta
+from unittest import mock
import pytest
from airflow import exceptions, settings
from airflow.decorators import task as task_deco
-from airflow.exceptions import AirflowException, AirflowSensorTimeout
+from airflow.exceptions import AirflowException, AirflowSensorTimeout,
TaskDeferred
from airflow.models import DagBag, DagRun, TaskInstance
from airflow.models.dag import DAG
from airflow.models.serialized_dag import SerializedDagModel
@@ -35,9 +36,14 @@ from airflow.models.xcom_arg import XComArg
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
-from airflow.sensors.external_task import ExternalTaskMarker,
ExternalTaskSensor, ExternalTaskSensorLink
+from airflow.sensors.external_task import (
+ ExternalTaskMarker,
+ ExternalTaskSensor,
+ ExternalTaskSensorLink,
+)
from airflow.sensors.time_sensor import TimeSensor
from airflow.serialization.serialized_objects import SerializedBaseOperator
+from airflow.triggers.external_task import TaskStateTrigger
from airflow.utils.hashlib_wrapper import md5
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
@@ -54,6 +60,9 @@ TEST_TASK_ID = "time_sensor_check"
TEST_TASK_ID_ALTERNATE = "time_sensor_check_alternate"
TEST_TASK_GROUP_ID = "time_sensor_group_id"
DEV_NULL = "/dev/null"
+TASK_ID = "external_task_sensor_check"
+EXTERNAL_DAG_ID = "child_dag" # DAG the external task sensor is waiting on
+EXTERNAL_TASK_ID = "child_task" # Task the external task sensor is waiting on
@pytest.fixture(autouse=True)
@@ -829,6 +838,75 @@ exit 0
)
+class TestExternalTaskAsyncSensor:
+ TASK_ID = "external_task_sensor_check"
+ EXTERNAL_DAG_ID = "child_dag" # DAG the external task sensor is waiting on
+ EXTERNAL_TASK_ID = "child_task" # Task the external task sensor is
waiting on
+
+ def test_defer_and_fire_task_state_trigger(self):
+ """
+ Asserts that a task is deferred and TaskStateTrigger will be fired
+ when the ExternalTaskAsyncSensor is provided with all required
arguments
+ (i.e. including the external_task_id).
+ """
+ sensor = ExternalTaskSensor(
+ task_id=TASK_ID,
+ external_task_id=EXTERNAL_TASK_ID,
+ external_dag_id=EXTERNAL_DAG_ID,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ sensor.execute(context=mock.MagicMock())
+
+ assert isinstance(exc.value.trigger, TaskStateTrigger), "Trigger is
not a TaskStateTrigger"
+
+ def test_defer_and_fire_failed_state_trigger(self):
+ """Tests that an AirflowException is raised in case of error event"""
+ sensor = ExternalTaskSensor(
+ task_id=TASK_ID,
+ external_task_id=EXTERNAL_TASK_ID,
+ external_dag_id=EXTERNAL_DAG_ID,
+ deferrable=True,
+ )
+
+ with pytest.raises(AirflowException):
+ sensor.execute_complete(
+ context=mock.MagicMock(), event={"status": "error", "message":
"test failure message"}
+ )
+
+ def test_defer_and_fire_timeout_state_trigger(self):
+ """Tests that an AirflowException is raised in case of timeout event"""
+ sensor = ExternalTaskSensor(
+ task_id=TASK_ID,
+ external_task_id=EXTERNAL_TASK_ID,
+ external_dag_id=EXTERNAL_DAG_ID,
+ deferrable=True,
+ )
+
+ with pytest.raises(AirflowException):
+ sensor.execute_complete(
+ context=mock.MagicMock(),
+ event={"status": "timeout", "message": "Dag was not started
within 1 minute, assuming fail."},
+ )
+
+ def test_defer_execute_check_correct_logging(self):
+ """Asserts that logging occurs as expected"""
+ sensor = ExternalTaskSensor(
+ task_id=TASK_ID,
+ external_task_id=EXTERNAL_TASK_ID,
+ external_dag_id=EXTERNAL_DAG_ID,
+ deferrable=True,
+ )
+
+ with mock.patch.object(sensor.log, "info") as mock_log_info:
+ sensor.execute_complete(
+ context=mock.MagicMock(),
+ event={"status": "success"},
+ )
+ mock_log_info.assert_called_with("External task %s has executed
successfully.", EXTERNAL_TASK_ID)
+
+
def test_external_task_sensor_check_zipped_dag_existence(dag_zip_maker):
with dag_zip_maker("test_external_task_sensor_check_existense.py") as
dagbag:
with create_session() as session:
diff --git a/tests/system/providers/core/__init__.py
b/tests/system/providers/core/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/system/providers/core/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git
a/tests/system/providers/core/example_external_task_child_deferrable.py
b/tests/system/providers/core/example_external_task_child_deferrable.py
new file mode 100644
index 0000000000..f75eb4f234
--- /dev/null
+++ b/tests/system/providers/core/example_external_task_child_deferrable.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow import DAG
+from airflow.operators.bash import BashOperator
+
+with DAG(
+ dag_id="child_dag",
+ start_date=datetime(2022, 1, 1),
+ schedule="@once",
+ catchup=False,
+ tags=["example", "async", "core"],
+) as dag:
+ dummy_task = BashOperator(
+ task_id="child_task",
+ bash_command="echo 1; sleep 1; echo 2; sleep 2; echo 3; sleep 3",
+ )
+
+
+from tests.system.utils import get_test_run
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git
a/tests/system/providers/core/example_external_task_parent_deferrable.py
b/tests/system/providers/core/example_external_task_parent_deferrable.py
new file mode 100644
index 0000000000..7cec2ce138
--- /dev/null
+++ b/tests/system/providers/core/example_external_task_parent_deferrable.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow import DAG
+from airflow.operators.dummy import DummyOperator
+from airflow.operators.trigger_dagrun import TriggerDagRunOperator
+from airflow.sensors.external_task import ExternalTaskSensor
+from airflow.utils.timezone import datetime
+
+with DAG(
+ dag_id="example_external_task",
+ start_date=datetime(2022, 1, 1),
+ schedule="@once",
+ catchup=False,
+ tags=["example", "async", "core"],
+) as dag:
+ start = DummyOperator(task_id="start")
+
+ # [START howto_external_task_async_sensor]
+ external_task_sensor = ExternalTaskSensor(
+ task_id="parent_task_sensor",
+ external_task_id="child_task",
+ external_dag_id="child_dag",
+ deferrable=True,
+ )
+ # [END howto_external_task_async_sensor]
+
+ trigger_child_task = TriggerDagRunOperator(
+ task_id="trigger_child_task",
+ trigger_dag_id="child_dag",
+ allowed_states=[
+ "success",
+ "failed",
+ ],
+ execution_date="{{execution_date}}",
+ poke_interval=5,
+ reset_dag_run=True,
+ wait_for_completion=True,
+ )
+
+ end = DummyOperator(task_id="end")
+
+ start >> [trigger_child_task, external_task_sensor] >> end
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "teardown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/triggers/test_external_task.py
b/tests/triggers/test_external_task.py
index e8a4a67ba5..a8569a9c55 100644
--- a/tests/triggers/test_external_task.py
+++ b/tests/triggers/test_external_task.py
@@ -26,6 +26,7 @@ from airflow.operators.empty import EmptyOperator
from airflow.triggers.external_task import DagStateTrigger, TaskStateTrigger
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState
+from airflow.utils.timezone import utcnow
class TestTaskStateTrigger:
@@ -40,6 +41,7 @@ class TestTaskStateTrigger:
Asserts that the TaskStateTrigger only goes off on or after a
TaskInstance
reaches an allowed state (i.e. SUCCESS).
"""
+ trigger_start_time = utcnow()
dag = DAG(self.DAG_ID, start_date=timezone.datetime(2022, 1, 1))
dag_run = DagRun(
dag_id=dag.dag_id,
@@ -61,6 +63,7 @@ class TestTaskStateTrigger:
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=0.2,
+ trigger_start_time=trigger_start_time,
)
task = asyncio.create_task(trigger.run().__anext__())
@@ -83,12 +86,14 @@ class TestTaskStateTrigger:
Asserts that the TaskStateTrigger correctly serializes its arguments
and classpath.
"""
+ trigger_start_time = utcnow()
trigger = TaskStateTrigger(
dag_id=self.DAG_ID,
task_id=self.TASK_ID,
states=self.STATES,
execution_dates=[timezone.datetime(2022, 1, 1)],
poll_interval=5,
+ trigger_start_time=trigger_start_time,
)
classpath, kwargs = trigger.serialize()
assert classpath == "airflow.triggers.external_task.TaskStateTrigger"
@@ -98,6 +103,7 @@ class TestTaskStateTrigger:
"states": self.STATES,
"execution_dates": [timezone.datetime(2022, 1, 1)],
"poll_interval": 5,
+ "trigger_start_time": trigger_start_time,
}