This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e9a4bcaa16 Improve ExternalTaskSensor Async Implementation (#36916)
e9a4bcaa16 is described below
commit e9a4bcaa16cd7627c22469d9aec3fadceb3361ac
Author: Pankaj Singh <[email protected]>
AuthorDate: Thu Jan 25 16:05:51 2024 +0530
Improve ExternalTaskSensor Async Implementation (#36916)
---
airflow/sensors/external_task.py | 102 +++++++++---------------
airflow/triggers/external_task.py | 99 +++++++++++++++++++++++
airflow/utils/sensor_helper.py | 123 +++++++++++++++++++++++++++++
tests/sensors/test_external_task_sensor.py | 6 +-
tests/triggers/test_external_task.py | 55 ++++++++++++-
5 files changed, 315 insertions(+), 70 deletions(-)
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index 5a3353d916..32f5992001 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -23,27 +23,24 @@ import warnings
from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable
import attr
-from sqlalchemy import func
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException, RemovedInAirflow3Warning
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DagModel
from airflow.models.dagbag import DagBag
-from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.sensors.base import BaseSensorOperator
-from airflow.triggers.external_task import TaskStateTrigger
+from airflow.triggers.external_task import WorkflowTrigger
from airflow.utils.file import correct_maybe_zipped
from airflow.utils.helpers import build_airflow_url_with_query
+from airflow.utils.sensor_helper import _get_count, _get_external_task_group_task_ids
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import State, TaskInstanceState
-from airflow.utils.timezone import utcnow
if TYPE_CHECKING:
- from sqlalchemy.orm import Query, Session
+ from sqlalchemy.orm import Session
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -351,13 +348,14 @@ class ExternalTaskSensor(BaseSensorOperator):
super().execute(context)
else:
self.defer(
- trigger=TaskStateTrigger(
- dag_id=self.external_dag_id,
- task_id=self.external_task_id,
+ timeout=self.execution_timeout,
+ trigger=WorkflowTrigger(
+ external_dag_id=self.external_dag_id,
+ external_task_ids=self.external_task_ids,
execution_dates=self._get_dttm_filter(context),
- states=self.allowed_states,
- trigger_start_time=utcnow(),
- poll_interval=self.poll_interval,
+ allowed_states=self.allowed_states,
+ poke_interval=self.poll_interval,
+ soft_fail=self.soft_fail,
),
method_name="execute_complete",
)
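For context, the deferral branch above is taken when the sensor runs in
deferrable mode; a minimal usage sketch, assuming hypothetical DAG/task ids
and the sensor's existing deferrable flag:

    from airflow.sensors.external_task import ExternalTaskSensor

    wait_for_upstream = ExternalTaskSensor(
        task_id="wait_for_upstream",
        external_dag_id="upstream_dag",    # hypothetical external DAG id
        external_task_id="upstream_task",  # hypothetical external task id
        allowed_states=["success"],
        deferrable=True,   # hand the waiting off to WorkflowTrigger
        poll_interval=10,  # forwarded to the trigger as poke_interval
    )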
@@ -365,15 +363,17 @@ class ExternalTaskSensor(BaseSensorOperator):
def execute_complete(self, context, event=None):
"""Execute when the trigger fires - return immediately."""
if event["status"] == "success":
- self.log.info("External task %s has executed successfully.",
self.external_task_id)
- return None
- elif event["status"] == "timeout":
- raise AirflowException("Dag was not started within 1 minute,
assuming fail.")
+ self.log.info("External tasks %s has executed successfully.",
self.external_task_ids)
+ elif event["status"] == "skipped":
+ raise AirflowSkipException("External job has skipped skipping.")
else:
- raise AirflowException(
- "Error occurred while trying to retrieve task status. Please,
check the "
- "name of executed task and Dag."
- )
+ if self.soft_fail:
+ raise AirflowSkipException("External job has failed skipping.")
+ else:
+ raise AirflowException(
+ "Error occurred while trying to retrieve task status.
Please, check the "
+ "name of executed task and Dag."
+ )
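For reference, execute_complete() now maps the trigger's event payloads as
follows; a minimal sketch (hypothetical ids, outside a real DAG run):

    from airflow.sensors.external_task import ExternalTaskSensor

    sensor = ExternalTaskSensor(
        task_id="wait", external_dag_id="upstream_dag", external_task_id="upstream_task"
    )
    sensor.execute_complete(context={}, event={"status": "success"})  # logs and returns
    # event={"status": "skipped"} raises AirflowSkipException
    # any other status raises AirflowException, or AirflowSkipException if soft_fail=True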
def _check_for_existence(self, session) -> None:
dag_to_wait = DagModel.get_current(self.external_dag_id, session)
@@ -412,55 +412,25 @@ class ExternalTaskSensor(BaseSensorOperator):
:param states: task or dag states
:return: count of records against the filters
"""
- TI = TaskInstance
- DR = DagRun
- if not dttm_filter:
- return 0
-
- if self.external_task_ids:
- count = (
- self._count_query(TI, session, states, dttm_filter)
- .filter(TI.task_id.in_(self.external_task_ids))
- .scalar()
- ) / len(self.external_task_ids)
- elif self.external_task_group_id:
- external_task_group_task_ids = self.get_external_task_group_task_ids(session, dttm_filter)
- if not external_task_group_task_ids:
- count = 0
- else:
- count = (
- self._count_query(TI, session, states, dttm_filter)
- .filter(tuple_in_condition((TI.task_id, TI.map_index), external_task_group_task_ids))
- .scalar()
- ) / len(external_task_group_task_ids)
- else:
- count = self._count_query(DR, session, states, dttm_filter).scalar()
- return count
-
- def _count_query(self, model, session, states, dttm_filter) -> Query:
- query = session.query(func.count()).filter(
- model.dag_id == self.external_dag_id,
- model.state.in_(states),
- model.execution_date.in_(dttm_filter),
+ warnings.warn(
+ "This method is deprecated and will be removed in future.",
DeprecationWarning, stacklevel=2
+ )
+ return _get_count(
+ dttm_filter,
+ self.external_task_ids,
+ self.external_task_group_id,
+ self.external_dag_id,
+ states,
+ session,
)
- return query
def get_external_task_group_task_ids(self, session, dttm_filter):
- refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session)
- task_group = refreshed_dag_info.task_group_dict.get(self.external_task_group_id)
-
- if task_group:
- group_tasks = session.query(TaskInstance).filter(
- TaskInstance.dag_id == self.external_dag_id,
- TaskInstance.task_id.in_(task.task_id for task in task_group),
- TaskInstance.execution_date.in_(dttm_filter),
- )
-
- return [(t.task_id, t.map_index) for t in group_tasks]
-
- # returning default task_id as group_id itself, this will avoid any failure in case of
- # 'check_existence=False' and will fail on timeout
- return [(self.external_task_group_id, -1)]
+ warnings.warn(
+ "This method is deprecated and will be removed in future.",
DeprecationWarning, stacklevel=2
+ )
+ return _get_external_task_group_task_ids(
+ dttm_filter, self.external_task_group_id, self.external_dag_id, session
+ )
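Both wrappers keep their public signatures and now just warn and delegate to
airflow.utils.sensor_helper; a sketch of observing the warning, assuming
hypothetical ids and that the external DAG is serialized in the metadata DB:

    import warnings

    from airflow.sensors.external_task import ExternalTaskSensor
    from airflow.utils import timezone
    from airflow.utils.session import create_session

    sensor = ExternalTaskSensor(
        task_id="wait",
        external_dag_id="upstream_dag",           # hypothetical
        external_task_group_id="upstream_group",  # hypothetical
    )
    with create_session() as session, warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        sensor.get_external_task_group_task_ids(session, [timezone.datetime(2024, 1, 1)])
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)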
def _handle_execution_date_fn(self, context) -> Any:
"""
diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py
index 269dbaebc8..98305ea4f1 100644
--- a/airflow/triggers/external_task.py
+++ b/airflow/triggers/external_task.py
@@ -18,12 +18,14 @@ from __future__ import annotations
import asyncio
import typing
+from typing import Any
from asgiref.sync import sync_to_async
from sqlalchemy import func
from airflow.models import DagRun, TaskInstance
from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.sensor_helper import _get_count
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
from airflow.utils.timezone import utcnow
@@ -36,6 +38,103 @@ if typing.TYPE_CHECKING:
from airflow.utils.state import DagRunState
+class WorkflowTrigger(BaseTrigger):
+ """
+ A trigger to monitor tasks, task groups, and DAG execution in Apache Airflow.
+
+ :param external_dag_id: The ID of the external DAG.
+ :param execution_dates: A list of execution dates for the external DAG.
+ :param external_task_ids: A collection of external task IDs to wait for.
+ :param external_task_group_id: The ID of the external task group to wait for.
+ :param failed_states: States considered as failed for external tasks.
+ :param skipped_states: States considered as skipped for external tasks.
+ :param allowed_states: States considered as successful for external tasks.
+ :param poke_interval: The interval (in seconds) for poking the external tasks.
+ :param soft_fail: If True, the trigger will not fail the entire DAG on external task failure.
+ """
+
+ def __init__(
+ self,
+ external_dag_id: str,
+ execution_dates: list,
+ external_task_ids: typing.Collection[str] | None = None,
+ external_task_group_id: str | None = None,
+ failed_states: typing.Iterable[str] | None = None,
+ skipped_states: typing.Iterable[str] | None = None,
+ allowed_states: typing.Iterable[str] | None = None,
+ poke_interval: float = 2.0,
+ soft_fail: bool = False,
+ **kwargs,
+ ):
+ self.external_dag_id = external_dag_id
+ self.external_task_ids = external_task_ids
+ self.external_task_group_id = external_task_group_id
+ self.failed_states = failed_states
+ self.skipped_states = skipped_states
+ self.allowed_states = allowed_states
+ self.execution_dates = execution_dates
+ self.poke_interval = poke_interval
+ self.soft_fail = soft_fail
+ super().__init__(**kwargs)
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serialize the trigger param and module path."""
+ return (
+ "airflow.triggers.external_task.WorkflowTrigger",
+ {
+ "external_dag_id": self.external_dag_id,
+ "external_task_ids": self.external_task_ids,
+ "external_task_group_id": self.external_task_group_id,
+ "failed_states": self.failed_states,
+ "skipped_states": self.skipped_states,
+ "allowed_states": self.allowed_states,
+ "execution_dates": self.execution_dates,
+ "poke_interval": self.poke_interval,
+ "soft_fail": self.soft_fail,
+ },
+ )
+
+ async def run(self) -> typing.AsyncIterator[TriggerEvent]:
+ """Check periodically tasks, task group or dag status."""
+ while True:
+ if self.failed_states:
+ failed_count = _get_count(
+ self.execution_dates,
+ self.external_task_ids,
+ self.external_task_group_id,
+ self.external_dag_id,
+ self.failed_states,
+ )
+ if failed_count > 0:
+ yield TriggerEvent({"status": "failed"})
+ return
+ else:
+ yield TriggerEvent({"status": "success"})
+ return
+ if self.skipped_states:
+ skipped_count = _get_count(
+ self.execution_dates,
+ self.external_task_ids,
+ self.external_task_group_id,
+ self.external_dag_id,
+ self.skipped_states,
+ )
+ if skipped_count > 0:
+ yield TriggerEvent({"status": "skipped"})
+ allowed_count = _get_count(
+ self.execution_dates,
+ self.external_task_ids,
+ self.external_task_group_id,
+ self.external_dag_id,
+ self.allowed_states,
+ )
+ if allowed_count == len(self.execution_dates):
+ yield TriggerEvent({"status": "success"})
+ return
+ self.log.info("Sleeping for %s seconds", self.poke_interval)
+ await asyncio.sleep(self.poke_interval)
+
+
class TaskStateTrigger(BaseTrigger):
"""
Waits asynchronously for a task in a different DAG to complete for a specific logical date.
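As the new tests below also do, the trigger can be driven one event at a time
with asend(); a minimal sketch, assuming hypothetical ids and a configured
metadata DB that already holds a matching successful run:

    import asyncio

    from airflow.triggers.external_task import WorkflowTrigger
    from airflow.utils import timezone

    trigger = WorkflowTrigger(
        external_dag_id="upstream_dag",       # hypothetical
        execution_dates=[timezone.datetime(2024, 1, 1)],
        external_task_ids=["upstream_task"],  # hypothetical
        allowed_states=["success"],
        poke_interval=5,
    )

    async def first_event():
        # run() is an async generator; asend(None) advances it to its first yield
        return await trigger.run().asend(None)

    event = asyncio.run(first_event())  # e.g. TriggerEvent({"status": "success"})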
diff --git a/airflow/utils/sensor_helper.py b/airflow/utils/sensor_helper.py
new file mode 100644
index 0000000000..fe72a70041
--- /dev/null
+++ b/airflow/utils/sensor_helper.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+
+from sqlalchemy import func, select
+
+from airflow.models import DagBag, DagRun, TaskInstance
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import tuple_in_condition
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Query, Session
+
+
+@provide_session
+def _get_count(
+ dttm_filter,
+ external_task_ids,
+ external_task_group_id,
+ external_dag_id,
+ states,
+ session: Session = NEW_SESSION,
+) -> int:
+ """
+ Get the count of records against the dttm filter and states.
+
+ :param dttm_filter: date time filter for execution date
+ :param external_task_ids: The list of task_ids
+ :param external_task_group_id: The ID of the external task group
+ :param external_dag_id: The ID of the external DAG.
+ :param states: task or dag states
+ :param session: airflow session object
+ """
+ TI = TaskInstance
+ DR = DagRun
+ if not dttm_filter:
+ return 0
+
+ if external_task_ids:
+ count = (
+ session.scalar(
+ _count_query(TI, states, dttm_filter, external_dag_id, session).filter(
+ TI.task_id.in_(external_task_ids)
+ )
+ )
+ ) / len(external_task_ids)
+ elif external_task_group_id:
+ external_task_group_task_ids = _get_external_task_group_task_ids(
+ dttm_filter, external_task_group_id, external_dag_id, session
+ )
+ if not external_task_group_task_ids:
+ count = 0
+ else:
+ count = (
+ session.scalar(
+ _count_query(TI, states, dttm_filter, external_dag_id, session).filter(
+ tuple_in_condition((TI.task_id, TI.map_index), external_task_group_task_ids)
+ )
+ )
+ ) / len(external_task_group_task_ids)
+ else:
+ count = session.scalar(_count_query(DR, states, dttm_filter, external_dag_id, session))
+ return cast(int, count)
+
+
+def _count_query(model, states, dttm_filter, external_dag_id, session: Session) -> Query:
+ """
+ Build the count query against the dttm filter and states.
+
+ :param model: The SQLAlchemy model representing the relevant table.
+ :param states: task or dag states
+ :param dttm_filter: date time filter for execution date
+ :param external_dag_id: The ID of the external DAG.
+ :param session: airflow session object
+ """
+ query = select(func.count()).filter(
+ model.dag_id == external_dag_id, model.state.in_(states), model.execution_date.in_(dttm_filter)
+ )
+ return query
+
+
+def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, external_dag_id, session):
+ """
+ Get the task ids of the external task group that match the dttm filter.
+
+ :param dttm_filter: date time filter for execution date
+ :param external_task_group_id: The ID of the external task group
+ :param external_dag_id: The ID of the external DAG.
+ :param session: airflow session object
+ """
+ refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(external_dag_id, session)
+ task_group = refreshed_dag_info.task_group_dict.get(external_task_group_id)
+
+ if task_group:
+ group_tasks = session.scalars(
+ select(TaskInstance).filter(
+ TaskInstance.dag_id == external_dag_id,
+ TaskInstance.task_id.in_(task.task_id for task in task_group),
+ TaskInstance.execution_date.in_(dttm_filter),
+ )
+ )
+
+ return [(t.task_id, t.map_index) for t in group_tasks]
+
+ # returning default task_id as group_id itself, this will avoid any failure in case of
+ # 'check_existence=False' and will fail on timeout
+ return [(external_task_group_id, -1)]
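Because _get_count is decorated with @provide_session, callers may omit the
session argument; a minimal sketch, assuming hypothetical ids and a configured
metadata DB:

    from airflow.utils import timezone
    from airflow.utils.sensor_helper import _get_count

    count = _get_count(
        dttm_filter=[timezone.datetime(2024, 1, 1)],
        external_task_ids=["upstream_task"],  # hypothetical
        external_task_group_id=None,
        external_dag_id="upstream_dag",       # hypothetical
        states=["success"],
    )
    # Equals len(dttm_filter) when every listed task reached one of the given
    # states on every date (the raw count is divided by len(external_task_ids)).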
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index bdbf7b1dc5..0330098f22 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -45,7 +45,7 @@ from airflow.sensors.external_task import (
)
from airflow.sensors.time_sensor import TimeSensor
from airflow.serialization.serialized_objects import SerializedBaseOperator
-from airflow.triggers.external_task import TaskStateTrigger
+from airflow.triggers.external_task import WorkflowTrigger
from airflow.utils.hashlib_wrapper import md5
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
@@ -996,7 +996,7 @@ class TestExternalTaskAsyncSensor:
with pytest.raises(TaskDeferred) as exc:
sensor.execute(context=mock.MagicMock())
- assert isinstance(exc.value.trigger, TaskStateTrigger), "Trigger is not a TaskStateTrigger"
+ assert isinstance(exc.value.trigger, WorkflowTrigger), "Trigger is not a WorkflowTrigger"
def test_defer_and_fire_failed_state_trigger(self):
"""Tests that an AirflowException is raised in case of error event"""
@@ -1041,7 +1041,7 @@ class TestExternalTaskAsyncSensor:
context=mock.MagicMock(),
event={"status": "success"},
)
- mock_log_info.assert_called_with("External task %s has executed
successfully.", EXTERNAL_TASK_ID)
+ mock_log_info.assert_called_with("External tasks %s has executed
successfully.", [EXTERNAL_TASK_ID])
def test_external_task_sensor_check_zipped_dag_existence(dag_zip_maker):
diff --git a/tests/triggers/test_external_task.py b/tests/triggers/test_external_task.py
index f8295b331c..f60d4660d9 100644
--- a/tests/triggers/test_external_task.py
+++ b/tests/triggers/test_external_task.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import asyncio
+from unittest import mock
import pytest
@@ -24,12 +25,64 @@ from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.operators.empty import EmptyOperator
-from airflow.triggers.external_task import DagStateTrigger, TaskStateTrigger
+from airflow.triggers.external_task import DagStateTrigger, TaskStateTrigger, WorkflowTrigger
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.timezone import utcnow
+class TestWorkflowTrigger:
+ DAG_ID = "external_task"
+ TASK_ID = "external_task_op"
+ RUN_ID = "external_task_run_id"
+ STATES = ["success", "fail"]
+
+ @mock.patch("airflow.triggers.external_task._get_count")
+ @mock.patch("asyncio.sleep")
+ @pytest.mark.asyncio
+ async def test_task_workflow_trigger(self, mock_sleep, mock_get_count):
+ """check the db count get called correctly."""
+ mock_get_count.return_value = 1
+ trigger = WorkflowTrigger(
+ external_dag_id=self.DAG_ID,
+ execution_dates=[timezone.datetime(2022, 1, 1)],
+ external_task_ids=[self.TASK_ID],
+ allowed_states=self.STATES,
+ poke_interval=0.2,
+ )
+
+ generator = trigger.run()
+ await generator.asend(None)
+ mock_get_count.assert_called_once_with(
+ [timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
+ )
+
+ def test_serialization(self):
+ """
+ Asserts that the WorkflowTrigger correctly serializes its arguments and classpath.
+ """
+ trigger = WorkflowTrigger(
+ external_dag_id=self.DAG_ID,
+ execution_dates=[timezone.datetime(2022, 1, 1)],
+ external_task_ids=[self.TASK_ID],
+ allowed_states=self.STATES,
+ poke_interval=5,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath == "airflow.triggers.external_task.WorkflowTrigger"
+ assert kwargs == {
+ "external_dag_id": self.DAG_ID,
+ "execution_dates": [timezone.datetime(2022, 1, 1)],
+ "external_task_ids": [self.TASK_ID],
+ "external_task_group_id": None,
+ "failed_states": None,
+ "skipped_states": None,
+ "allowed_states": self.STATES,
+ "poke_interval": 5,
+ "soft_fail": False,
+ }
+
+
class TestTaskStateTrigger:
DAG_ID = "external_task"
TASK_ID = "external_task_op"