This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e9a4bcaa16 Improve ExternalTaskSensor Async Implementation (#36916)
e9a4bcaa16 is described below

commit e9a4bcaa16cd7627c22469d9aec3fadceb3361ac
Author: Pankaj Singh <[email protected]>
AuthorDate: Thu Jan 25 16:05:51 2024 +0530

    Improve ExternalTaskSensor Async Implementation (#36916)
---
 airflow/sensors/external_task.py           | 102 +++++++++---------------
 airflow/triggers/external_task.py          |  99 +++++++++++++++++++++++
 airflow/utils/sensor_helper.py             | 123 +++++++++++++++++++++++++++++
 tests/sensors/test_external_task_sensor.py |   6 +-
 tests/triggers/test_external_task.py       |  55 ++++++++++++-
 5 files changed, 315 insertions(+), 70 deletions(-)

diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index 5a3353d916..32f5992001 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -23,27 +23,24 @@ import warnings
 from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable
 
 import attr
-from sqlalchemy import func
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException, 
RemovedInAirflow3Warning
 from airflow.models.baseoperatorlink import BaseOperatorLink
 from airflow.models.dag import DagModel
 from airflow.models.dagbag import DagBag
-from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
 from airflow.operators.empty import EmptyOperator
 from airflow.sensors.base import BaseSensorOperator
-from airflow.triggers.external_task import TaskStateTrigger
+from airflow.triggers.external_task import WorkflowTrigger
 from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.helpers import build_airflow_url_with_query
+from airflow.utils.sensor_helper import _get_count, 
_get_external_task_group_task_ids
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import State, TaskInstanceState
-from airflow.utils.timezone import utcnow
 
 if TYPE_CHECKING:
-    from sqlalchemy.orm import Query, Session
+    from sqlalchemy.orm import Session
 
     from airflow.models.baseoperator import BaseOperator
     from airflow.models.taskinstancekey import TaskInstanceKey
@@ -351,13 +348,14 @@ class ExternalTaskSensor(BaseSensorOperator):
             super().execute(context)
         else:
             self.defer(
-                trigger=TaskStateTrigger(
-                    dag_id=self.external_dag_id,
-                    task_id=self.external_task_id,
+                timeout=self.execution_timeout,
+                trigger=WorkflowTrigger(
+                    external_dag_id=self.external_dag_id,
+                    external_task_ids=self.external_task_ids,
                     execution_dates=self._get_dttm_filter(context),
-                    states=self.allowed_states,
-                    trigger_start_time=utcnow(),
-                    poll_interval=self.poll_interval,
+                    allowed_states=self.allowed_states,
+                    poke_interval=self.poll_interval,
+                    soft_fail=self.soft_fail,
                 ),
                 method_name="execute_complete",
             )
@@ -365,15 +363,17 @@ class ExternalTaskSensor(BaseSensorOperator):
     def execute_complete(self, context, event=None):
         """Execute when the trigger fires - return immediately."""
         if event["status"] == "success":
-            self.log.info("External task %s has executed successfully.", 
self.external_task_id)
-            return None
-        elif event["status"] == "timeout":
-            raise AirflowException("Dag was not started within 1 minute, 
assuming fail.")
+            self.log.info("External tasks %s has executed successfully.", 
self.external_task_ids)
+        elif event["status"] == "skipped":
+            raise AirflowSkipException("External job has skipped; skipping.")
         else:
-            raise AirflowException(
-                "Error occurred while trying to retrieve task status. Please, 
check the "
-                "name of executed task and Dag."
-            )
+            if self.soft_fail:
+                raise AirflowSkipException("External job has failed; skipping.")
+            else:
+                raise AirflowException(
+                    "Error occurred while trying to retrieve task status. 
Please, check the "
+                    "name of executed task and Dag."
+                )
 
     def _check_for_existence(self, session) -> None:
         dag_to_wait = DagModel.get_current(self.external_dag_id, session)
@@ -412,55 +412,25 @@ class ExternalTaskSensor(BaseSensorOperator):
         :param states: task or dag states
         :return: count of record against the filters
         """
-        TI = TaskInstance
-        DR = DagRun
-        if not dttm_filter:
-            return 0
-
-        if self.external_task_ids:
-            count = (
-                self._count_query(TI, session, states, dttm_filter)
-                .filter(TI.task_id.in_(self.external_task_ids))
-                .scalar()
-            ) / len(self.external_task_ids)
-        elif self.external_task_group_id:
-            external_task_group_task_ids = 
self.get_external_task_group_task_ids(session, dttm_filter)
-            if not external_task_group_task_ids:
-                count = 0
-            else:
-                count = (
-                    self._count_query(TI, session, states, dttm_filter)
-                    .filter(tuple_in_condition((TI.task_id, TI.map_index), 
external_task_group_task_ids))
-                    .scalar()
-                ) / len(external_task_group_task_ids)
-        else:
-            count = self._count_query(DR, session, states, 
dttm_filter).scalar()
-        return count
-
-    def _count_query(self, model, session, states, dttm_filter) -> Query:
-        query = session.query(func.count()).filter(
-            model.dag_id == self.external_dag_id,
-            model.state.in_(states),
-            model.execution_date.in_(dttm_filter),
+        warnings.warn(
+            "This method is deprecated and will be removed in future.", 
DeprecationWarning, stacklevel=2
+        )
+        return _get_count(
+            dttm_filter,
+            self.external_task_ids,
+            self.external_task_group_id,
+            self.external_dag_id,
+            states,
+            session,
         )
-        return query
 
     def get_external_task_group_task_ids(self, session, dttm_filter):
-        refreshed_dag_info = 
DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session)
-        task_group = 
refreshed_dag_info.task_group_dict.get(self.external_task_group_id)
-
-        if task_group:
-            group_tasks = session.query(TaskInstance).filter(
-                TaskInstance.dag_id == self.external_dag_id,
-                TaskInstance.task_id.in_(task.task_id for task in task_group),
-                TaskInstance.execution_date.in_(dttm_filter),
-            )
-
-            return [(t.task_id, t.map_index) for t in group_tasks]
-
-        # returning default task_id as group_id itself, this will avoid any 
failure in case of
-        # 'check_existence=False' and will fail on timeout
-        return [(self.external_task_group_id, -1)]
+        warnings.warn(
+            "This method is deprecated and will be removed in future.", 
DeprecationWarning, stacklevel=2
+        )
+        return _get_external_task_group_task_ids(
+            dttm_filter, self.external_task_group_id, self.external_dag_id, 
session
+        )
 
     def _handle_execution_date_fn(self, context) -> Any:
         """
diff --git a/airflow/triggers/external_task.py 
b/airflow/triggers/external_task.py
index 269dbaebc8..98305ea4f1 100644
--- a/airflow/triggers/external_task.py
+++ b/airflow/triggers/external_task.py
@@ -18,12 +18,14 @@ from __future__ import annotations
 
 import asyncio
 import typing
+from typing import Any
 
 from asgiref.sync import sync_to_async
 from sqlalchemy import func
 
 from airflow.models import DagRun, TaskInstance
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.sensor_helper import _get_count
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import TaskInstanceState
 from airflow.utils.timezone import utcnow
@@ -36,6 +38,103 @@ if typing.TYPE_CHECKING:
     from airflow.utils.state import DagRunState
 
 
+class WorkflowTrigger(BaseTrigger):
+    """
+    A trigger to monitor tasks, task group and dag execution in Apache Airflow.
+
+    :param external_dag_id: The ID of the external DAG.
+    :param execution_dates: A list of execution dates for the external DAG.
+    :param external_task_ids: A collection of external task IDs to wait for.
+    :param external_task_group_id: The ID of the external task group to wait 
for.
+    :param failed_states: States considered as failed for external tasks.
+    :param skipped_states: States considered as skipped for external tasks.
+    :param allowed_states: States considered as successful for external tasks.
+    :param poke_interval: The interval (in seconds) for poking the external 
tasks.
+    :param soft_fail: If True, the trigger will not fail the entire DAG on 
external task failure.
+    """
+
+    def __init__(
+        self,
+        external_dag_id: str,
+        execution_dates: list,
+        external_task_ids: typing.Collection[str] | None = None,
+        external_task_group_id: str | None = None,
+        failed_states: typing.Iterable[str] | None = None,
+        skipped_states: typing.Iterable[str] | None = None,
+        allowed_states: typing.Iterable[str] | None = None,
+        poke_interval: float = 2.0,
+        soft_fail: bool = False,
+        **kwargs,
+    ):
+        self.external_dag_id = external_dag_id
+        self.external_task_ids = external_task_ids
+        self.external_task_group_id = external_task_group_id
+        self.failed_states = failed_states
+        self.skipped_states = skipped_states
+        self.allowed_states = allowed_states
+        self.execution_dates = execution_dates
+        self.poke_interval = poke_interval
+        self.soft_fail = soft_fail
+        super().__init__(**kwargs)
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize the trigger param and module path."""
+        return (
+            "airflow.triggers.external_task.WorkflowTrigger",
+            {
+                "external_dag_id": self.external_dag_id,
+                "external_task_ids": self.external_task_ids,
+                "external_task_group_id": self.external_task_group_id,
+                "failed_states": self.failed_states,
+                "skipped_states": self.skipped_states,
+                "allowed_states": self.allowed_states,
+                "execution_dates": self.execution_dates,
+                "poke_interval": self.poke_interval,
+                "soft_fail": self.soft_fail,
+            },
+        )
+
+    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
+        """Check periodically tasks, task group or dag status."""
+        while True:
+            if self.failed_states:
+                failed_count = _get_count(
+                    self.execution_dates,
+                    self.external_task_ids,
+                    self.external_task_group_id,
+                    self.external_dag_id,
+                    self.failed_states,
+                )
+                if failed_count > 0:
+                    yield TriggerEvent({"status": "failed"})
+                    return
+                else:
+                    yield TriggerEvent({"status": "success"})
+                    return
+            if self.skipped_states:
+                skipped_count = _get_count(
+                    self.execution_dates,
+                    self.external_task_ids,
+                    self.external_task_group_id,
+                    self.external_dag_id,
+                    self.skipped_states,
+                )
+                if skipped_count > 0:
+                    yield TriggerEvent({"status": "skipped"})
+            allowed_count = _get_count(
+                self.execution_dates,
+                self.external_task_ids,
+                self.external_task_group_id,
+                self.external_dag_id,
+                self.allowed_states,
+            )
+            if allowed_count == len(self.execution_dates):
+                yield TriggerEvent({"status": "success"})
+                return
+            self.log.info("Sleeping for %s seconds", self.poke_interval)
+            await asyncio.sleep(self.poke_interval)
+
+
 class TaskStateTrigger(BaseTrigger):
     """
     Waits asynchronously for a task in a different DAG to complete for a 
specific logical date.
diff --git a/airflow/utils/sensor_helper.py b/airflow/utils/sensor_helper.py
new file mode 100644
index 0000000000..fe72a70041
--- /dev/null
+++ b/airflow/utils/sensor_helper.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+
+from sqlalchemy import func, select
+
+from airflow.models import DagBag, DagRun, TaskInstance
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import tuple_in_condition
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Query, Session
+
+
+@provide_session
+def _get_count(
+    dttm_filter,
+    external_task_ids,
+    external_task_group_id,
+    external_dag_id,
+    states,
+    session: Session = NEW_SESSION,
+) -> int:
+    """
+    Get the count of records against dttm filter and states.
+
+    :param dttm_filter: date time filter for execution date
+    :param external_task_ids: The list of task_ids
+    :param external_task_group_id: The ID of the external task group
+    :param external_dag_id: The ID of the external DAG.
+    :param states: task or dag states
+    :param session: airflow session object
+    """
+    TI = TaskInstance
+    DR = DagRun
+    if not dttm_filter:
+        return 0
+
+    if external_task_ids:
+        count = (
+            session.scalar(
+                _count_query(TI, states, dttm_filter, external_dag_id, 
session).filter(
+                    TI.task_id.in_(external_task_ids)
+                )
+            )
+        ) / len(external_task_ids)
+    elif external_task_group_id:
+        external_task_group_task_ids = _get_external_task_group_task_ids(
+            dttm_filter, external_task_group_id, external_dag_id, session
+        )
+        if not external_task_group_task_ids:
+            count = 0
+        else:
+            count = (
+                session.scalar(
+                    _count_query(TI, states, dttm_filter, external_dag_id, 
session).filter(
+                        tuple_in_condition((TI.task_id, TI.map_index), 
external_task_group_task_ids)
+                    )
+                )
+            ) / len(external_task_group_task_ids)
+    else:
+        count = session.scalar(_count_query(DR, states, dttm_filter, 
external_dag_id, session))
+    return cast(int, count)
+
+
+def _count_query(model, states, dttm_filter, external_dag_id, session: 
Session) -> Query:
+    """
+    Build the query that counts records against dttm filter and states.
+
+    :param model: The SQLAlchemy model representing the relevant table.
+    :param states: task or dag states
+    :param dttm_filter: date time filter for execution date
+    :param external_dag_id: The ID of the external DAG.
+    :param session: airflow session object
+    """
+    query = select(func.count()).filter(
+        model.dag_id == external_dag_id, model.state.in_(states), 
model.execution_date.in_(dttm_filter)
+    )
+    return query
+
+
+def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, 
external_dag_id, session):
+    """
+    Get the task ids of the external task group matching the dttm filter.
+
+    :param dttm_filter: date time filter for execution date
+    :param external_task_group_id: The ID of the external task group
+    :param external_dag_id: The ID of the external DAG.
+    :param session: airflow session object
+    """
+    refreshed_dag_info = 
DagBag(read_dags_from_db=True).get_dag(external_dag_id, session)
+    task_group = refreshed_dag_info.task_group_dict.get(external_task_group_id)
+
+    if task_group:
+        group_tasks = session.scalars(
+            select(TaskInstance).filter(
+                TaskInstance.dag_id == external_dag_id,
+                TaskInstance.task_id.in_(task.task_id for task in task_group),
+                TaskInstance.execution_date.in_(dttm_filter),
+            )
+        )
+
+        return [(t.task_id, t.map_index) for t in group_tasks]
+
+    # returning default task_id as group_id itself, this will avoid any 
failure in case of
+    # 'check_existence=False' and will fail on timeout
+    return [(external_task_group_id, -1)]
diff --git a/tests/sensors/test_external_task_sensor.py 
b/tests/sensors/test_external_task_sensor.py
index bdbf7b1dc5..0330098f22 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -45,7 +45,7 @@ from airflow.sensors.external_task import (
 )
 from airflow.sensors.time_sensor import TimeSensor
 from airflow.serialization.serialized_objects import SerializedBaseOperator
-from airflow.triggers.external_task import TaskStateTrigger
+from airflow.triggers.external_task import WorkflowTrigger
 from airflow.utils.hashlib_wrapper import md5
 from airflow.utils.session import create_session, provide_session
 from airflow.utils.state import DagRunState, State, TaskInstanceState
@@ -996,7 +996,7 @@ class TestExternalTaskAsyncSensor:
         with pytest.raises(TaskDeferred) as exc:
             sensor.execute(context=mock.MagicMock())
 
-        assert isinstance(exc.value.trigger, TaskStateTrigger), "Trigger is 
not a TaskStateTrigger"
+        assert isinstance(exc.value.trigger, WorkflowTrigger), "Trigger is not 
a WorkflowTrigger"
 
     def test_defer_and_fire_failed_state_trigger(self):
         """Tests that an AirflowException is raised in case of error event"""
@@ -1041,7 +1041,7 @@ class TestExternalTaskAsyncSensor:
                 context=mock.MagicMock(),
                 event={"status": "success"},
             )
-        mock_log_info.assert_called_with("External task %s has executed 
successfully.", EXTERNAL_TASK_ID)
+        mock_log_info.assert_called_with("External tasks %s has executed 
successfully.", [EXTERNAL_TASK_ID])
 
 
 def test_external_task_sensor_check_zipped_dag_existence(dag_zip_maker):
diff --git a/tests/triggers/test_external_task.py 
b/tests/triggers/test_external_task.py
index f8295b331c..f60d4660d9 100644
--- a/tests/triggers/test_external_task.py
+++ b/tests/triggers/test_external_task.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import asyncio
+from unittest import mock
 
 import pytest
 
@@ -24,12 +25,64 @@ from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
 from airflow.operators.empty import EmptyOperator
-from airflow.triggers.external_task import DagStateTrigger, TaskStateTrigger
+from airflow.triggers.external_task import DagStateTrigger, TaskStateTrigger, 
WorkflowTrigger
 from airflow.utils import timezone
 from airflow.utils.state import DagRunState, TaskInstanceState
 from airflow.utils.timezone import utcnow
 
 
+class TestWorkflowTrigger:
+    DAG_ID = "external_task"
+    TASK_ID = "external_task_op"
+    RUN_ID = "external_task_run_id"
+    STATES = ["success", "fail"]
+
+    @mock.patch("airflow.triggers.external_task._get_count")
+    @mock.patch("asyncio.sleep")
+    @pytest.mark.asyncio
+    async def test_task_workflow_trigger(self, mock_sleep, mock_get_count):
+        """check the db count get called correctly."""
+        mock_get_count.return_value = 1
+        trigger = WorkflowTrigger(
+            external_dag_id=self.DAG_ID,
+            execution_dates=[timezone.datetime(2022, 1, 1)],
+            external_task_ids=[self.TASK_ID],
+            allowed_states=self.STATES,
+            poke_interval=0.2,
+        )
+
+        generator = trigger.run()
+        await generator.asend(None)
+        mock_get_count.assert_called_once_with(
+            [timezone.datetime(2022, 1, 1)], ["external_task_op"], None, 
"external_task", ["success", "fail"]
+        )
+
+    def test_serialization(self):
+        """
+        Asserts that the WorkflowTrigger correctly serializes its arguments 
and classpath.
+        """
+        trigger = WorkflowTrigger(
+            external_dag_id=self.DAG_ID,
+            execution_dates=[timezone.datetime(2022, 1, 1)],
+            external_task_ids=[self.TASK_ID],
+            allowed_states=self.STATES,
+            poke_interval=5,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == "airflow.triggers.external_task.WorkflowTrigger"
+        assert kwargs == {
+            "external_dag_id": self.DAG_ID,
+            "execution_dates": [timezone.datetime(2022, 1, 1)],
+            "external_task_ids": [self.TASK_ID],
+            "external_task_group_id": None,
+            "failed_states": None,
+            "skipped_states": None,
+            "allowed_states": self.STATES,
+            "poke_interval": 5,
+            "soft_fail": False,
+        }
+
+
 class TestTaskStateTrigger:
     DAG_ID = "external_task"
     TASK_ID = "external_task_op"

Reply via email to