This is an automated email from the ASF dual-hosted git repository.
pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 688f91b330 Add deferrable mode to `BatchSensor` (#30279)
688f91b330 is described below
commit 688f91b330addbc88a5bbda2f0e29cbed2313678
Author: Phani Kumar <[email protected]>
AuthorDate: Wed Jun 14 07:42:46 2023 +0530
Add deferrable mode to `BatchSensor` (#30279)
* Implement BatchAsyncSensor
---
airflow/providers/amazon/aws/sensors/batch.py | 44 +++++++-
airflow/providers/amazon/aws/triggers/batch.py | 83 +++++++++++++++
.../operators/batch.rst | 9 ++
tests/providers/amazon/aws/sensors/test_batch.py | 40 +++++++-
tests/providers/amazon/aws/triggers/test_batch.py | 114 ++++++++++++++++++++-
5 files changed, 287 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py
index c93fc3d8b3..475b0ecb71 100644
--- a/airflow/providers/amazon/aws/sensors/batch.py
+++ b/airflow/providers/amazon/aws/sensors/batch.py
@@ -16,13 +16,15 @@
# under the License.
from __future__ import annotations
+from datetime import timedelta
from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from deprecated import deprecated
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger
from airflow.sensors.base import BaseSensorOperator
if TYPE_CHECKING:
@@ -41,6 +43,10 @@ class BatchSensor(BaseSensorOperator):
:param job_id: Batch job_id to check the state for
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
:param region_name: aws region name associated with the client
+ :param deferrable: Run sensor in the deferrable mode.
+ :param poke_interval: polling period in seconds to check for the status of the job.
+ :param max_retries: Number of times to poll for job state before
+ returning the current state.
"""
template_fields: Sequence[str] = ("job_id",)
@@ -53,12 +59,18 @@ class BatchSensor(BaseSensorOperator):
job_id: str,
aws_conn_id: str = "aws_default",
region_name: str | None = None,
+ deferrable: bool = False,
+ poke_interval: float = 5,
+ max_retries: int = 5,
**kwargs,
):
super().__init__(**kwargs)
self.job_id = job_id
self.aws_conn_id = aws_conn_id
self.region_name = region_name
+ self.deferrable = deferrable
+ self.poke_interval = poke_interval
+ self.max_retries = max_retries
def poke(self, context: Context) -> bool:
job_description = self.hook.get_job_description(self.job_id)
@@ -75,6 +87,36 @@ class BatchSensor(BaseSensorOperator):
raise AirflowException(f"Batch sensor failed. Unknown AWS Batch job
status: {state}")
+ def execute(self, context: Context) -> None:
+ if not self.deferrable:
+ super().execute(context=context)
+ else:
+ timeout = (
+ timedelta(seconds=self.max_retries * self.poke_interval + 60)
+ if self.max_retries
+ else self.execution_timeout
+ )
+ self.defer(
+ timeout=timeout,
+ trigger=BatchSensorTrigger(
+ job_id=self.job_id,
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ poke_interval=self.poke_interval,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context: Context, event: dict[str, Any]) ->
None:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes
execution was
+ successful.
+ """
+ if "status" in event and event["status"] == "failure":
+ raise AirflowException(event["message"])
+ self.log.info(event["message"])
+
@deprecated(reason="use `hook` property instead.")
def get_hook(self) -> BatchClientHook:
"""Create and return a BatchClientHook."""
diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py
index dc858a80fd..f4a5de1525 100644
--- a/airflow/providers/amazon/aws/triggers/batch.py
+++ b/airflow/providers/amazon/aws/triggers/batch.py
@@ -105,3 +105,86 @@ class BatchOperatorTrigger(BaseTrigger):
yield TriggerEvent({"status": "failure", "message": "Job Failed -
max attempts reached."})
else:
yield TriggerEvent({"status": "success", "job_id": self.job_id})
+
+
+class BatchSensorTrigger(BaseTrigger):
+ """
+ Checks for the status of a submitted job_id to AWS Batch until it reaches a failure or a success state.
+ BatchSensorTrigger is fired as a deferred class with params to poll the job state in the Triggerer.
+
+ :param job_id: the job ID to poll for completion
+ :param region_name: AWS region name to use
+ Override the region_name in connection (if provided)
+ :param aws_conn_id: connection id of AWS credentials / region name. If None,
+ credential boto3 strategy will be used
+ :param poke_interval: polling period in seconds to check for the status of the job
+ """
+
+ def __init__(
+ self,
+ job_id: str,
+ region_name: str | None,
+ aws_conn_id: str | None = "aws_default",
+ poke_interval: float = 5,
+ ):
+ super().__init__()
+ self.job_id = job_id
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+ self.poke_interval = poke_interval
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes BatchSensorTrigger arguments and classpath."""
+ return (
+ "airflow.providers.amazon.aws.triggers.batch.BatchSensorTrigger",
+ {
+ "job_id": self.job_id,
+ "aws_conn_id": self.aws_conn_id,
+ "region_name": self.region_name,
+ "poke_interval": self.poke_interval,
+ },
+ )
+
+ @cached_property
+ def hook(self) -> BatchClientHook:
+ return BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+ async def run(self):
+ """
+ Make async connection using aiobotocore library to AWS Batch,
+ periodically poll for the Batch job status.
+
+ The statuses that indicate job completion are: 'SUCCEEDED'|'FAILED'.
+ """
+ async with self.hook.async_conn as client:
+ waiter = self.hook.get_waiter("batch_job_complete",
deferrable=True, client=client)
+ attempt = 0
+ while True:
+ attempt = attempt + 1
+ try:
+ await waiter.wait(
+ jobs=[self.job_id],
+ WaiterConfig={
+ "Delay": int(self.poke_interval),
+ "MaxAttempts": 1,
+ },
+ )
+ break
+ except WaiterError as error:
+ if "error" in str(error):
+ yield TriggerEvent({"status": "failure", "message":
f"Job Failed: {error}"})
+ break
+ self.log.info(
+ "Job response is %s. Retrying attempt %s",
+ error.last_response["Error"]["Message"],
+ attempt,
+ )
+ await asyncio.sleep(int(self.poke_interval))
+
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "job_id": self.job_id,
+ "message": f"Job {self.job_id} Succeeded",
+ }
+ )
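
Because BatchSensorTrigger.run() is an async generator, its first TriggerEvent
can be pulled directly, the same way the tests below drive it with asend(None).
A scratch-script sketch (assuming valid AWS credentials and an existing Batch
job; the job id and region are illustrative):

    import asyncio

    from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger

    async def main() -> None:
        trigger = BatchSensorTrigger(
            job_id="8222a1c2-b246-4e19-b1b8-0039bb4407c0",  # illustrative job id
            region_name="eu-west-1",  # illustrative region
            aws_conn_id="aws_default",
            poke_interval=5,
        )
        # Pull the first event; its payload carries
        # {"status": "success"|"failure", ...} once the job settles.
        event = await trigger.run().asend(None)
        print(event.payload)

    asyncio.run(main())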
diff --git a/docs/apache-airflow-providers-amazon/operators/batch.rst b/docs/apache-airflow-providers-amazon/operators/batch.rst
index bcfb86dbf7..4cc2a2b0cc 100644
--- a/docs/apache-airflow-providers-amazon/operators/batch.rst
+++ b/docs/apache-airflow-providers-amazon/operators/batch.rst
@@ -77,6 +77,15 @@ use :class:`~airflow.providers.amazon.aws.sensors.batch.BatchSensor`.
:start-after: [START howto_sensor_batch]
:end-before: [END howto_sensor_batch]
+In order to monitor the state of the AWS Batch Job asynchronously, use
+:class:`~airflow.providers.amazon.aws.sensors.batch.BatchSensor` with the
+parameter ``deferrable`` set to True.
+
+Since this will release the Airflow worker slot, it will lead to efficient
+utilization of available resources on your Airflow deployment.
+This will also need the triggerer component to be available in your
+Airflow deployment.
+
.. _howto/sensor:BatchComputeEnvironmentSensor:
Wait on an AWS Batch compute environment status
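
To make the documented usage concrete, a minimal DAG sketch (illustrative only,
not taken from the commit; the dag_id, job id, region, and tuning values are
placeholders):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.sensors.batch import BatchSensor

    with DAG(
        dag_id="example_batch_sensor_deferrable",  # placeholder dag_id
        start_date=datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ):
        # deferrable=True hands polling to a BatchSensorTrigger running in the
        # triggerer, freeing the worker slot while the Batch job runs.
        wait_for_batch_job = BatchSensor(
            task_id="wait_for_batch_job",
            job_id="8222a1c2-b246-4e19-b1b8-0039bb4407c0",  # placeholder job id
            region_name="eu-west-1",  # placeholder region
            deferrable=True,
            poke_interval=30,
            max_retries=60,
        )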
diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py
index 835b99ad0a..42e9bffb5b 100644
--- a/tests/providers/amazon/aws/sensors/test_batch.py
+++ b/tests/providers/amazon/aws/sensors/test_batch.py
@@ -20,16 +20,18 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.sensors.batch import (
BatchComputeEnvironmentSensor,
BatchJobQueueSensor,
BatchSensor,
)
+from airflow.providers.amazon.aws.triggers.batch import BatchSensorTrigger
TASK_ID = "batch_job_sensor"
JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
+AWS_REGION = "eu-west-1"
class TestBatchSensor:
@@ -195,3 +197,39 @@ class TestBatchJobQueueSensor:
jobQueues=[self.job_queue],
)
assert "AWS Batch job queue failed" in str(ctx.value)
+
+
+class TestBatchAsyncSensor:
+ TASK = BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION,
deferrable=True)
+
+ def test_batch_sensor_async(self):
+ """
+ Asserts that a task is deferred and a BatchSensorTrigger will be fired
+ when the BatchSensor is executed in deferrable mode.
+ """
+
+ with pytest.raises(TaskDeferred) as exc:
+ self.TASK.execute({})
+ assert isinstance(exc.value.trigger, BatchSensorTrigger), "Trigger is
not a BatchSensorTrigger"
+
+ def test_batch_sensor_async_execute_failure(self):
+ """Tests that an AirflowException is raised in case of error event"""
+
+ with pytest.raises(AirflowException) as exc_info:
+ self.TASK.execute_complete(
+ context={}, event={"status": "failure", "message": "test
failure message"}
+ )
+
+ assert str(exc_info.value) == "test failure message"
+
+ @pytest.mark.parametrize(
+ "event",
+ [{"status": "success", "message": f"AWS Batch job ({JOB_ID})
succeeded"}],
+ )
+ def test_batch_sensor_async_execute_complete(self, caplog, event):
+ """Tests that execute_complete method returns None and that it prints
expected log"""
+
+ with mock.patch.object(self.TASK.log, "info") as mock_log_info:
+ assert self.TASK.execute_complete(context={}, event=event) is None
+
+ mock_log_info.assert_called_with(event["message"])
diff --git a/tests/providers/amazon/aws/triggers/test_batch.py b/tests/providers/amazon/aws/triggers/test_batch.py
index 6f87d92a2d..5cf125f828 100644
--- a/tests/providers/amazon/aws/triggers/test_batch.py
+++ b/tests/providers/amazon/aws/triggers/test_batch.py
@@ -20,8 +20,9 @@ from unittest import mock
from unittest.mock import AsyncMock
import pytest
+from botocore.exceptions import WaiterError
-from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger, BatchSensorTrigger
from airflow.triggers.base import TriggerEvent
BATCH_JOB_ID = "job_id"
@@ -29,6 +30,7 @@ POLL_INTERVAL = 5
MAX_ATTEMPT = 5
AWS_CONN_ID = "aws_batch_job_conn"
AWS_REGION = "us-east-2"
+pytest.importorskip("aiobotocore")
class TestBatchOperatorTrigger:
@@ -69,3 +71,113 @@ class TestBatchOperatorTrigger:
response = await generator.asend(None)
assert response == TriggerEvent({"status": "success", "job_id":
BATCH_JOB_ID})
+
+
+class TestBatchSensorTrigger:
+ TRIGGER = BatchSensorTrigger(
+ job_id=BATCH_JOB_ID,
+ region_name=AWS_REGION,
+ aws_conn_id=AWS_CONN_ID,
+ poke_interval=POLL_INTERVAL,
+ )
+
+ def test_batch_sensor_trigger_serialization(self):
+ """
+ Asserts that the BatchSensorTrigger correctly serializes its arguments
+ and classpath.
+ """
+
+ classpath, kwargs = self.TRIGGER.serialize()
assert classpath == "airflow.providers.amazon.aws.triggers.batch.BatchSensorTrigger"
+ assert kwargs == {
+ "job_id": BATCH_JOB_ID,
+ "region_name": AWS_REGION,
+ "aws_conn_id": AWS_CONN_ID,
+ "poke_interval": POLL_INTERVAL,
+ }
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn")
+ async def test_batch_job_trigger_run(self, mock_async_conn, mock_get_waiter):
+ the_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = the_mock
+
+ mock_get_waiter().wait = AsyncMock()
+
+ batch_trigger = BatchOperatorTrigger(
+ job_id=BATCH_JOB_ID,
+ poll_interval=POLL_INTERVAL,
+ max_retries=MAX_ATTEMPT,
+ aws_conn_id=AWS_CONN_ID,
+ region_name=AWS_REGION,
+ )
+
+ generator = batch_trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent({"status": "success", "job_id":
BATCH_JOB_ID})
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn")
+ @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description")
+ async def test_batch_sensor_trigger_completed(self, mock_response, mock_async_conn, mock_get_waiter):
+ """Test if the success event is returned from trigger."""
+ mock_response.return_value = {"status": "SUCCEEDED"}
+
+ the_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = the_mock
+
+ mock_get_waiter().wait = AsyncMock()
+
+ trigger = BatchSensorTrigger(
+ job_id=BATCH_JOB_ID,
+ region_name=AWS_REGION,
+ aws_conn_id=AWS_CONN_ID,
+ )
+ generator = trigger.run()
+ actual_response = await generator.asend(None)
+ assert (
+ TriggerEvent(
+ {"status": "success", "job_id": BATCH_JOB_ID, "message": f"Job
{BATCH_JOB_ID} Succeeded"}
+ )
+ == actual_response
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch("asyncio.sleep")
+ @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter")
+ @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description")
+ @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn")
+ async def test_batch_sensor_trigger_failure(
+ self, mock_async_conn, mock_response, mock_get_waiter, mock_sleep
+ ):
+ """Test if the failure event is returned from trigger."""
+ a_mock = mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = a_mock
+
+ mock_response.return_value = {"status": "failed"}
+
+ name = "batch_job_complete"
+ reason = (
+ "An error occurred (UnrecognizedClientException): The security
token included in the "
+ "request is invalid. "
+ )
+ last_response = ({"Error": {"Message": "The security token included in
the request is invalid."}},)
+
+ error_failed = WaiterError(
+ name=name,
+ reason=reason,
+ last_response=last_response,
+ )
+
+ mock_get_waiter().wait.side_effect = AsyncMock(side_effect=[error_failed])
+ mock_sleep.return_value = True
+
+ trigger = BatchSensorTrigger(job_id=BATCH_JOB_ID, region_name=AWS_REGION, aws_conn_id=AWS_CONN_ID)
+ generator = trigger.run()
+ actual_response = await generator.asend(None)
+ assert actual_response == TriggerEvent(
+ {"status": "failure", "message": f"Job Failed: Waiter {name}
failed: {reason}"}
+ )