This is an automated email from the ASF dual-hosted git repository.

pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 688f91b330 Add deferrable mode to `BatchSensor`  (#30279)
688f91b330 is described below

commit 688f91b330addbc88a5bbda2f0e29cbed2313678
Author: Phani Kumar <[email protected]>
AuthorDate: Wed Jun 14 07:42:46 2023 +0530

    Add deferrable mode to `BatchSensor`  (#30279)
    
    * Implement BatchAsyncSensor
---
 airflow/providers/amazon/aws/sensors/batch.py      |  44 +++++++-
 airflow/providers/amazon/aws/triggers/batch.py     |  83 +++++++++++++++
 .../operators/batch.rst                            |   9 ++
 tests/providers/amazon/aws/sensors/test_batch.py   |  40 +++++++-
 tests/providers/amazon/aws/triggers/test_batch.py  | 114 ++++++++++++++++++++-
 5 files changed, 287 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/amazon/aws/sensors/batch.py 
b/airflow/providers/amazon/aws/sensors/batch.py
index c93fc3d8b3..475b0ecb71 100644
--- a/airflow/providers/amazon/aws/sensors/batch.py
+++ b/airflow/providers/amazon/aws/sensors/batch.py
@@ -16,13 +16,15 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from deprecated import deprecated
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -41,6 +43,10 @@ class BatchSensor(BaseSensorOperator):
     :param job_id: Batch job_id to check the state for
     :param aws_conn_id: aws connection to use, defaults to 'aws_default'
     :param region_name: aws region name associated with the client
+    :param deferrable: Run sensor in the deferrable mode.
+    :param poke_interval: polling period in seconds to check for the status of 
the job.
+    :param max_retries: Number of times to poll for job state before
+        returning the current state.
     """
 
     template_fields: Sequence[str] = ("job_id",)
@@ -53,12 +59,18 @@ class BatchSensor(BaseSensorOperator):
         job_id: str,
         aws_conn_id: str = "aws_default",
         region_name: str | None = None,
+        deferrable: bool = False,
+        poke_interval: float = 5,
+        max_retries: int = 5,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.job_id = job_id
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
 
     def poke(self, context: Context) -> bool:
         job_description = self.hook.get_job_description(self.job_id)
@@ -75,6 +87,36 @@ class BatchSensor(BaseSensorOperator):
 
         raise AirflowException(f"Batch sensor failed. Unknown AWS Batch job 
status: {state}")
 
+    def execute(self, context: Context) -> None:
+        if not self.deferrable:
+            super().execute(context=context)
+        else:
+            timeout = (
+                timedelta(seconds=self.max_retries * self.poke_interval + 60)
+                if self.max_retries
+                else self.execution_timeout
+            )
+            self.defer(
+                timeout=timeout,
+                trigger=BatchSensorTrigger(
+                    job_id=self.job_id,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    poke_interval=self.poke_interval,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> 
None:
+        """
+        Callback for when the trigger fires - returns immediately.
+        Relies on trigger to throw an exception, otherwise it assumes 
execution was
+        successful.
+        """
+        if "status" in event and event["status"] == "failure":
+            raise AirflowException(event["message"])
+        self.log.info(event["message"])
+
     @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> BatchClientHook:
         """Create and return a BatchClientHook."""
diff --git a/airflow/providers/amazon/aws/triggers/batch.py 
b/airflow/providers/amazon/aws/triggers/batch.py
index dc858a80fd..f4a5de1525 100644
--- a/airflow/providers/amazon/aws/triggers/batch.py
+++ b/airflow/providers/amazon/aws/triggers/batch.py
@@ -105,3 +105,86 @@ class BatchOperatorTrigger(BaseTrigger):
             yield TriggerEvent({"status": "failure", "message": "Job Failed - 
max attempts reached."})
         else:
             yield TriggerEvent({"status": "success", "job_id": self.job_id})
+
+
+class BatchSensorTrigger(BaseTrigger):
+    """
+    Checks for the status of a submitted job_id to AWS Batch until it reaches 
a failure or a success state.
+    BatchSensorTrigger is fired as deferred class with params to poll the job 
state in Triggerer.
+
+    :param job_id: the job ID, to poll for job completion or not
+    :param region_name: AWS region name to use
+        Override the region_name in connection (if provided)
+    :param aws_conn_id: connection id of AWS credentials / region name. If 
None,
+        credential boto3 strategy will be used
+    :param poke_interval: polling period in seconds to check for the status of 
the job
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        region_name: str | None,
+        aws_conn_id: str | None = "aws_default",
+        poke_interval: float = 5,
+    ):
+        super().__init__()
+        self.job_id = job_id
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.poke_interval = poke_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes BatchSensorTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.batch.BatchSensorTrigger",
+            {
+                "job_id": self.job_id,
+                "aws_conn_id": self.aws_conn_id,
+                "region_name": self.region_name,
+                "poke_interval": self.poke_interval,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> BatchClientHook:
+        return BatchClientHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
+
+    async def run(self):
+        """
+        Make async connection using aiobotocore library to AWS Batch,
+        periodically poll for the Batch job status.
+
+        The statuses that indicate job completion are: 'SUCCEEDED'|'FAILED'.
+        """
+        async with self.hook.async_conn as client:
+            waiter = self.hook.get_waiter("batch_job_complete", 
deferrable=True, client=client)
+            attempt = 0
+            while True:
+                attempt = attempt + 1
+                try:
+                    await waiter.wait(
+                        jobs=[self.job_id],
+                        WaiterConfig={
+                            "Delay": int(self.poke_interval),
+                            "MaxAttempts": 1,
+                        },
+                    )
+                    break
+                except WaiterError as error:
+                    if "error" in str(error):
+                        yield TriggerEvent({"status": "failure", "message": 
f"Job Failed: {error}"})
+                        break
+                    self.log.info(
+                        "Job response is %s. Retrying attempt %s",
+                        error.last_response["Error"]["Message"],
+                        attempt,
+                    )
+                    await asyncio.sleep(int(self.poke_interval))
+
+            yield TriggerEvent(
+                {
+                    "status": "success",
+                    "job_id": self.job_id,
+                    "message": f"Job {self.job_id} Succeeded",
+                }
+            )
diff --git a/docs/apache-airflow-providers-amazon/operators/batch.rst 
b/docs/apache-airflow-providers-amazon/operators/batch.rst
index bcfb86dbf7..4cc2a2b0cc 100644
--- a/docs/apache-airflow-providers-amazon/operators/batch.rst
+++ b/docs/apache-airflow-providers-amazon/operators/batch.rst
@@ -77,6 +77,15 @@ use 
:class:`~airflow.providers.amazon.aws.sensors.batch.BatchSensor`.
     :start-after: [START howto_sensor_batch]
     :end-before: [END howto_sensor_batch]
 
+In order to monitor the state of the AWS Batch Job asynchronously, use
+:class:`~airflow.providers.amazon.aws.sensors.batch.BatchSensor` with the
+parameter ``deferrable`` set to True.
+
+Since this will release the Airflow worker slot, it will lead to efficient
+utilization of available resources on your Airflow deployment.
+This will also need the triggerer component to be available in your
+Airflow deployment.
+
 .. _howto/sensor:BatchComputeEnvironmentSensor:
 
 Wait on an AWS Batch compute environment status
diff --git a/tests/providers/amazon/aws/sensors/test_batch.py 
b/tests/providers/amazon/aws/sensors/test_batch.py
index 835b99ad0a..42e9bffb5b 100644
--- a/tests/providers/amazon/aws/sensors/test_batch.py
+++ b/tests/providers/amazon/aws/sensors/test_batch.py
@@ -20,16 +20,18 @@ from unittest import mock
 
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
 from airflow.providers.amazon.aws.sensors.batch import (
     BatchComputeEnvironmentSensor,
     BatchJobQueueSensor,
     BatchSensor,
 )
+from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger
 
 TASK_ID = "batch_job_sensor"
 JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
+AWS_REGION = "eu-west-1"
 
 
 class TestBatchSensor:
@@ -195,3 +197,39 @@ class TestBatchJobQueueSensor:
             jobQueues=[self.job_queue],
         )
         assert "AWS Batch job queue failed" in str(ctx.value)
+
+
+class TestBatchAsyncSensor:
+    TASK = BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, 
deferrable=True)
+
+    def test_batch_sensor_async(self):
+        """
+        Asserts that a task is deferred and a BatchSensorTrigger will be fired
+        when the BatchSensorAsync is executed.
+        """
+
+        with pytest.raises(TaskDeferred) as exc:
+            self.TASK.execute({})
+        assert isinstance(exc.value.trigger, BatchSensorTrigger), "Trigger is 
not a BatchSensorTrigger"
+
+    def test_batch_sensor_async_execute_failure(self):
+        """Tests that an AirflowException is raised in case of error event"""
+
+        with pytest.raises(AirflowException) as exc_info:
+            self.TASK.execute_complete(
+                context={}, event={"status": "failure", "message": "test 
failure message"}
+            )
+
+        assert str(exc_info.value) == "test failure message"
+
+    @pytest.mark.parametrize(
+        "event",
+        [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) 
succeeded"}],
+    )
+    def test_batch_sensor_async_execute_complete(self, caplog, event):
+        """Tests that execute_complete method returns None and that it prints 
expected log"""
+
+        with mock.patch.object(self.TASK.log, "info") as mock_log_info:
+            assert self.TASK.execute_complete(context={}, event=event) is None
+
+        mock_log_info.assert_called_with(event["message"])
diff --git a/tests/providers/amazon/aws/triggers/test_batch.py 
b/tests/providers/amazon/aws/triggers/test_batch.py
index 6f87d92a2d..5cf125f828 100644
--- a/tests/providers/amazon/aws/triggers/test_batch.py
+++ b/tests/providers/amazon/aws/triggers/test_batch.py
@@ -20,8 +20,9 @@ from unittest import mock
 from unittest.mock import AsyncMock
 
 import pytest
+from botocore.exceptions import WaiterError
 
-from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger, 
BatchSensorTrigger
 from airflow.triggers.base import TriggerEvent
 
 BATCH_JOB_ID = "job_id"
@@ -29,6 +30,7 @@ POLL_INTERVAL = 5
 MAX_ATTEMPT = 5
 AWS_CONN_ID = "aws_batch_job_conn"
 AWS_REGION = "us-east-2"
+pytest.importorskip("aiobotocore")
 
 
 class TestBatchOperatorTrigger:
@@ -69,3 +71,113 @@ class TestBatchOperatorTrigger:
         response = await generator.asend(None)
 
         assert response == TriggerEvent({"status": "success", "job_id": 
BATCH_JOB_ID})
+
+
+class TestBatchSensorTrigger:
+    TRIGGER = BatchSensorTrigger(
+        job_id=BATCH_JOB_ID,
+        region_name=AWS_REGION,
+        aws_conn_id=AWS_CONN_ID,
+        poke_interval=POLL_INTERVAL,
+    )
+
+    def test_batch_sensor_trigger_serialization(self):
+        """
+        Asserts that the BatchSensorTrigger correctly serializes its arguments
+        and classpath.
+        """
+
+        classpath, kwargs = self.TRIGGER.serialize()
+        assert classpath == 
"airflow.providers.amazon.aws.triggers.batch.BatchSensorTrigger"
+        assert kwargs == {
+            "job_id": BATCH_JOB_ID,
+            "region_name": AWS_REGION,
+            "aws_conn_id": AWS_CONN_ID,
+            "poke_interval": POLL_INTERVAL,
+        }
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn")
+    async def test_batch_job_trigger_run(self, mock_async_conn, 
mock_get_waiter):
+        the_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = the_mock
+
+        mock_get_waiter().wait = AsyncMock()
+
+        batch_trigger = BatchOperatorTrigger(
+            job_id=BATCH_JOB_ID,
+            poll_interval=POLL_INTERVAL,
+            max_retries=MAX_ATTEMPT,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=AWS_REGION,
+        )
+
+        generator = batch_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "success", "job_id": 
BATCH_JOB_ID})
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description")
+    async def test_batch_sensor_trigger_completed(self, mock_response, 
mock_async_conn, mock_get_waiter):
+        """Test if the success event is returned from trigger."""
+        mock_response.return_value = {"status": "SUCCEEDED"}
+
+        the_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = the_mock
+
+        mock_get_waiter().wait = AsyncMock()
+
+        trigger = BatchSensorTrigger(
+            job_id=BATCH_JOB_ID,
+            region_name=AWS_REGION,
+            aws_conn_id=AWS_CONN_ID,
+        )
+        generator = trigger.run()
+        actual_response = await generator.asend(None)
+        assert (
+            TriggerEvent(
+                {"status": "success", "job_id": BATCH_JOB_ID, "message": f"Job 
{BATCH_JOB_ID} Succeeded"}
+            )
+            == actual_response
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.sleep")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn")
+    async def test_batch_sensor_trigger_failure(
+        self, mock_async_conn, mock_response, mock_get_waiter, mock_sleep
+    ):
+        """Test if the failure event is returned from trigger."""
+        a_mock = mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = a_mock
+
+        mock_response.return_value = {"status": "failed"}
+
+        name = "batch_job_complete"
+        reason = (
+            "An error occurred (UnrecognizedClientException): The security 
token included in the "
+            "request is invalid. "
+        )
+        last_response = ({"Error": {"Message": "The security token included in 
the request is invalid."}},)
+
+        error_failed = WaiterError(
+            name=name,
+            reason=reason,
+            last_response=last_response,
+        )
+
+        mock_get_waiter().wait.side_effect = 
AsyncMock(side_effect=[error_failed])
+        mock_sleep.return_value = True
+
+        trigger = BatchSensorTrigger(job_id=BATCH_JOB_ID, 
region_name=AWS_REGION, aws_conn_id=AWS_CONN_ID)
+        generator = trigger.run()
+        actual_response = await generator.asend(None)
+        assert actual_response == TriggerEvent(
+            {"status": "failure", "message": f"Job Failed: Waiter {name} 
failed: {reason}"}
+        )

Reply via email to