This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 77c272e6e8 Add AWS deferrable BatchOperator (#29300)
77c272e6e8 is described below
commit 77c272e6e8ecda0ce48917064e58ba14f6a15844
Author: Rajath <[email protected]>
AuthorDate: Wed Apr 5 20:27:13 2023 +0530
Add AWS deferrable BatchOperator (#29300)
This PR donates the deferrable BatchOperator developed in the
[astronomer-providers](https://github.com/astronomer/astronomer-providers) repo
to Apache Airflow.
---
airflow/providers/amazon/aws/hooks/batch_client.py | 239 ++++++++++++++++++++-
airflow/providers/amazon/aws/operators/batch.py | 37 ++++
airflow/providers/amazon/aws/triggers/batch.py | 123 +++++++++++
.../operators/batch.rst | 5 +-
.../aws/deferrable/hooks/test_batch_client.py | 213 ++++++++++++++++++
.../amazon/aws/deferrable/triggers/test_batch.py | 131 +++++++++++
.../amazon/aws/hooks/test_batch_client.py | 1 +
tests/providers/amazon/aws/operators/test_batch.py | 120 ++++++++++-
8 files changed, 866 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py
b/airflow/providers/amazon/aws/hooks/batch_client.py
index 526ab9a8a4..10b93afbbc 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -26,15 +26,17 @@ A client for AWS Batch services
"""
from __future__ import annotations
+import asyncio
from random import uniform
from time import sleep
+from typing import Any
import botocore.client
import botocore.exceptions
import botocore.waiter
from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook,
AwsBaseHook
from airflow.typing_compat import Protocol, runtime_checkable
@@ -544,3 +546,238 @@ class BatchClientHook(AwsBaseHook):
delay = 1 + pow(tries * 0.6, 2)
delay = min(max_interval, delay)
return uniform(delay / 3, delay)
+
+
+class BatchClientAsyncHook(BatchClientHook, AwsBaseAsyncHook):
+ """
+ Async client for AWS Batch services.
+
+ :param job_id: the job ID, usually unknown (None) until the
+ submit_job operation gets the jobId defined by AWS Batch
+
+ :param waiters: an :py:class:`.BatchWaiters` object (see note below);
+ if None, polling is used with max_retries and status_retries.
+
+ .. note::
+ Several methods use a default random delay to check or poll for job
status, i.e.
+ ``random.sample()``
+ Using a random interval helps to avoid AWS API throttle limits
+ when many concurrent tasks request job-descriptions.
+
+ To modify the global defaults for the range of jitter allowed when a
+ random delay is used to check Batch job status, modify these defaults,
e.g.:
+
+ BatchClient.DEFAULT_DELAY_MIN = 0
+ BatchClient.DEFAULT_DELAY_MAX = 5
+
+ When explicit delay values are used, a 1 second random jitter is
applied to the
+ delay . It is generally recommended that random jitter is added to
API requests.
+ A convenience method is provided for this, e.g. to get a random delay
of
+ 10 sec +/- 5 sec: ``delay = BatchClient.add_jitter(10, width=5,
minima=0)``
+ """
+
+ def __init__(self, job_id: str | None, waiters: Any = None, *args: Any,
**kwargs: Any) -> None:
+ super().__init__(*args, **kwargs)
+ self.job_id = job_id
+ self.waiters = waiters
+
+ async def monitor_job(self) -> dict[str, str] | None:
+ """
+ Monitor an AWS Batch job
+ monitor_job can raise an exception or an AirflowTaskTimeout can be
raised if execution_timeout
+ is given while creating the task. These exceptions should be handled
in taskinstance.py
+ instead of here like it was previously done
+
+ :raises: AirflowException
+ """
+ if not self.job_id:
+ raise AirflowException("AWS Batch job - job_id was not found")
+
+ if self.waiters:
+ self.waiters.wait_for_job(self.job_id)
+ return None
+ else:
+ await self.wait_for_job(self.job_id)
+ await self.check_job_success(self.job_id)
+ success_msg = f"AWS Batch job ({self.job_id}) succeeded"
+ self.log.info(success_msg)
+ return {"status": "success", "message": success_msg}
+
+ async def check_job_success(self, job_id: str) -> bool: # type:
ignore[override]
+ """
+ Check the final status of the Batch job; return True if the job
+ 'SUCCEEDED', else raise an AirflowException
+
+ :param job_id: a Batch job ID
+
+ :raises: AirflowException
+ """
+ job = await self.get_job_description(job_id)
+ job_status = job.get("status")
+ if job_status == self.SUCCESS_STATE:
+ self.log.info("AWS Batch job (%s) succeeded: %s", job_id, job)
+ return True
+
+ if job_status == self.FAILURE_STATE:
+ raise AirflowException(f"AWS Batch job ({job_id}) failed: {job}")
+
+ if job_status in self.INTERMEDIATE_STATES:
+ raise AirflowException(f"AWS Batch job ({job_id}) is not complete:
{job}")
+
+ raise AirflowException(f"AWS Batch job ({job_id}) has unknown status:
{job}")
+
+ @staticmethod
+ async def delay(delay: int | float | None = None) -> None: # type:
ignore[override]
+ """
+ Pause execution for ``delay`` seconds.
+
+ :param delay: a delay to pause execution using ``time.sleep(delay)``;
+ a small 1 second jitter is applied to the delay.
+
+ .. note::
+ This method uses a default random delay, i.e.
+ ``random.sample()``;
+ using a random interval helps to avoid AWS API throttle limits
+ when many concurrent tasks request job-descriptions.
+ """
+ if delay is None:
+ delay = uniform(BatchClientHook.DEFAULT_DELAY_MIN,
BatchClientHook.DEFAULT_DELAY_MAX)
+ else:
+ delay = BatchClientAsyncHook.add_jitter(delay)
+ await asyncio.sleep(delay)
+
+ async def wait_for_job( # type: ignore[override]
+ self, job_id: str, delay: int | float | None = None
+ ) -> None:
+ """
+ Wait for Batch job to complete.
+
+ :param job_id: a Batch job ID
+
+ :param delay: a delay before polling for job status
+
+ :raises: AirflowException
+ """
+ await self.delay(delay)
+ await self.poll_for_job_running(job_id, delay)
+ await self.poll_for_job_complete(job_id, delay)
+ self.log.info("AWS Batch job (%s) has completed", job_id)
+
+ async def poll_for_job_complete( # type: ignore[override]
+ self, job_id: str, delay: int | float | None = None
+ ) -> None:
+ """
+ Poll for job completion. The status that indicates job completion
+ are: 'SUCCEEDED'|'FAILED'.
+
+ So the status options that this will wait for are the transitions from:
+
'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED'
+
+ :param job_id: a Batch job ID
+
+ :param delay: a delay before polling for job status
+
+ :raises: AirflowException
+ """
+ await self.delay(delay)
+ complete_status = [self.SUCCESS_STATE, self.FAILURE_STATE]
+ await self.poll_job_status(job_id, complete_status)
+
+ async def poll_for_job_running( # type: ignore[override]
+ self, job_id: str, delay: int | float | None = None
+ ) -> None:
+ """
+ Poll for job running. The status that indicates a job is running or
+ already complete are: 'RUNNING'|'SUCCEEDED'|'FAILED'.
+
+ So the status options that this will wait for are the transitions from:
+
'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'|'SUCCEEDED'|'FAILED'
+
+ The completed status options are included for cases where the status
+ changes too quickly for polling to detect a RUNNING status that moves
+ quickly from STARTING to RUNNING to completed (often a failure).
+
+ :param job_id: a Batch job ID
+
+ :param delay: a delay before polling for job status
+
+ :raises: AirflowException
+ """
+ await self.delay(delay)
+ running_status = [self.RUNNING_STATE, self.SUCCESS_STATE,
self.FAILURE_STATE]
+ await self.poll_job_status(job_id, running_status)
+
+ async def get_job_description(self, job_id: str) -> dict[str, str]: #
type: ignore[override]
+ """
+ Get job description (using status_retries).
+
+ :param job_id: a Batch job ID
+ :raises: AirflowException
+ """
+ retries = 0
+ async with await self.get_client_async() as client:
+ while True:
+ try:
+ response = client.describe_jobs(jobs=[job_id])
+ return self.parse_job_description(job_id, response)
+
+ except botocore.exceptions.ClientError as err:
+ error = err.response.get("Error", {})
+ if error.get("Code") == "TooManyRequestsException":
+ pass # allow it to retry, if possible
+ else:
+ raise AirflowException(f"AWS Batch job ({job_id})
description error: {err}")
+
+ retries += 1
+ if retries >= self.status_retries:
+ raise AirflowException(
+ f"AWS Batch job ({job_id}) description error: exceeded
status_retries "
+ f"({self.status_retries})"
+ )
+
+ pause = self.exponential_delay(retries)
+ self.log.info(
+ "AWS Batch job (%s) description retry (%d of %d) in the
next %.2f seconds",
+ job_id,
+ retries,
+ self.status_retries,
+ pause,
+ )
+ await self.delay(pause)
+
+ async def poll_job_status(self, job_id: str, match_status: list[str]) ->
bool: # type: ignore[override]
+ """
+ Poll for job status using an exponential back-off strategy (with
max_retries).
+ The Batch job status polled are:
+
'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
+
+ :param job_id: a Batch job ID
+ :param match_status: a list of job status to match
+ :raises: AirflowException
+ """
+ retries = 0
+ while True:
+ job = await self.get_job_description(job_id)
+ job_status = job.get("status")
+ self.log.info(
+ "AWS Batch job (%s) check status (%s) in %s",
+ job_id,
+ job_status,
+ match_status,
+ )
+ if job_status in match_status:
+ return True
+
+ if retries >= self.max_retries:
+ raise AirflowException(f"AWS Batch job ({job_id}) status
checks exceed max_retries")
+
+ retries += 1
+ pause = self.exponential_delay(retries)
+ self.log.info(
+ "AWS Batch job (%s) status check (%d of %d) in the next %.2f
seconds",
+ job_id,
+ retries,
+ self.max_retries,
+ pause,
+ )
+ await self.delay(pause)
diff --git a/airflow/providers/amazon/aws/operators/batch.py
b/airflow/providers/amazon/aws/operators/batch.py
index 6565bcecfb..79a10a7b17 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -37,6 +37,7 @@ from airflow.providers.amazon.aws.links.batch import (
BatchJobQueueLink,
)
from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
from airflow.providers.amazon.aws.utils import trim_none_values
if TYPE_CHECKING:
@@ -71,6 +72,7 @@ class BatchOperator(BaseOperator):
Override the region_name in connection (if provided)
:param tags: collection of tags to apply to the AWS Batch job submission
if None, no tags are submitted
+ :param deferrable: Run operator in the deferrable mode.
.. note::
Any custom waiters must return a waiter for these calls:
@@ -125,6 +127,7 @@ class BatchOperator(BaseOperator):
region_name: str | None = None,
tags: dict | None = None,
wait_for_completion: bool = True,
+ deferrable: bool = False,
**kwargs,
):
@@ -139,6 +142,8 @@ class BatchOperator(BaseOperator):
self.waiters = waiters
self.tags = tags or {}
self.wait_for_completion = wait_for_completion
+ self.deferrable = deferrable
+
self.hook = BatchClientHook(
max_retries=max_retries,
status_retries=status_retries,
@@ -154,11 +159,43 @@ class BatchOperator(BaseOperator):
"""
self.submit_job(context)
+ if self.deferrable:
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=BatchOperatorTrigger(
+ job_id=self.job_id,
+ job_name=self.job_name,
+ job_definition=self.job_definition,
+ job_queue=self.job_queue,
+ overrides=self.overrides,
+ array_properties=self.array_properties,
+ parameters=self.parameters,
+ waiters=self.waiters,
+ tags=self.tags,
+ max_retries=self.hook.max_retries,
+ status_retries=self.hook.status_retries,
+ aws_conn_id=self.hook.aws_conn_id,
+ region_name=self.hook.region_name,
+ ),
+ method_name="execute_complete",
+ )
+
if self.wait_for_completion:
self.monitor_job(context)
return self.job_id
+ def execute_complete(self, context: Context, event: dict[str, Any]):
+ """
+ Callback for when the trigger fires - returns immediately.
+ Relies on trigger to throw an exception, otherwise it assumes
execution was
+ successful.
+ """
+ if "status" in event and event["status"] == "error":
+ raise AirflowException(event["message"])
+ self.log.info(event["message"])
+ return self.job_id
+
def on_kill(self):
response = self.hook.client.terminate_job(jobId=self.job_id,
reason="Task killed by the user")
self.log.info("AWS Batch job (%s) terminated: %s", self.job_id,
response)
diff --git a/airflow/providers/amazon/aws/triggers/batch.py
b/airflow/providers/amazon/aws/triggers/batch.py
new file mode 100644
index 0000000000..eb5a80a3c9
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/batch.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any, AsyncIterator
+
+from airflow.providers.amazon.aws.hooks.batch_client import
BatchClientAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class BatchOperatorTrigger(BaseTrigger):
+ """
+ Checks for the state of a previously submitted job to AWS Batch.
+ BatchOperatorTrigger is fired as deferred class with params to poll the
job state in Triggerer
+
+ :param job_id: the job ID, usually unknown (None) until the
+ submit_job operation gets the jobId defined by AWS Batch
+ :param job_name: the name for the job that will run on AWS Batch
(templated)
+ :param job_definition: the job definition name on AWS Batch
+ :param job_queue: the queue name on AWS Batch
+ :param overrides: the `containerOverrides` parameter for boto3 (templated)
+ :param array_properties: the `arrayProperties` parameter for boto3
+ :param parameters: the `parameters` for boto3 (templated)
+ :param waiters: a :class:`.BatchWaiters` object (see note below);
+ if None, polling is used with max_retries and status_retries.
+ :param tags: collection of tags to apply to the AWS Batch job submission
+ if None, no tags are submitted
+ :param max_retries: exponential back-off retries, 4200 = 48 hours;
+ polling is only used when waiters is None
+ :param status_retries: number of HTTP retries to get job status, 10;
+ polling is only used when waiters is None
+ :param aws_conn_id: connection id of AWS credentials / region name. If
None,
+ credential boto3 strategy will be used.
+ :param region_name: AWS region name to use .
+ Override the region_name in connection (if provided)
+ """
+
+ def __init__(
+ self,
+ job_id: str | None,
+ job_name: str,
+ job_definition: str,
+ job_queue: str,
+ overrides: dict[str, str],
+ array_properties: dict[str, str],
+ parameters: dict[str, str],
+ waiters: Any,
+ tags: dict[str, str],
+ max_retries: int,
+ status_retries: int,
+ region_name: str | None,
+ aws_conn_id: str | None = "aws_default",
+ ):
+ super().__init__()
+ self.job_id = job_id
+ self.job_name = job_name
+ self.job_definition = job_definition
+ self.job_queue = job_queue
+ self.overrides = overrides or {}
+ self.array_properties = array_properties or {}
+ self.parameters = parameters or {}
+ self.waiters = waiters
+ self.tags = tags or {}
+ self.max_retries = max_retries
+ self.status_retries = status_retries
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes BatchOperatorTrigger arguments and classpath."""
+ return (
+ "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger",
+ {
+ "job_id": self.job_id,
+ "job_name": self.job_name,
+ "job_definition": self.job_definition,
+ "job_queue": self.job_queue,
+ "overrides": self.overrides,
+ "array_properties": self.array_properties,
+ "parameters": self.parameters,
+ "waiters": self.waiters,
+ "tags": self.tags,
+ "max_retries": self.max_retries,
+ "status_retries": self.status_retries,
+ "aws_conn_id": self.aws_conn_id,
+ "region_name": self.region_name,
+ },
+ )
+
+ async def run(self) -> AsyncIterator["TriggerEvent"]:
+ """
+ Make async connection using aiobotocore library to AWS Batch,
+ periodically poll for the job status on the Triggerer
+
+ The status that indicates job completion are: 'SUCCEEDED'|'FAILED'.
+
+ So the status options that this will poll for are the transitions from:
+
'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED'
+ """
+ hook = BatchClientAsyncHook(job_id=self.job_id, waiters=self.waiters,
aws_conn_id=self.aws_conn_id)
+ try:
+ response = await hook.monitor_job()
+ if response:
+ yield TriggerEvent(response)
+ else:
+ error_message = f"{self.job_id} failed"
+ yield TriggerEvent({"status": "error", "message":
error_message})
+ except Exception as e:
+ yield TriggerEvent({"status": "error", "message": str(e)})
diff --git a/docs/apache-airflow-providers-amazon/operators/batch.rst
b/docs/apache-airflow-providers-amazon/operators/batch.rst
index ba280cb38d..0c686184b9 100644
--- a/docs/apache-airflow-providers-amazon/operators/batch.rst
+++ b/docs/apache-airflow-providers-amazon/operators/batch.rst
@@ -37,7 +37,10 @@ Operators
Submit a new AWS Batch job
==========================
-To submit a new AWS Batch job and monitor it until it reaches a terminal state
you can
+To submit a new AWS Batch job and monitor it until it reaches a terminal state
+(optionally in deferrable mode, by setting the parameter ``deferrable`` to True, so that
+polling for job status happens asynchronously on the triggerer and Airflow workers are used
+efficiently; this requires a triggerer to be available on your Airflow deployment), you can
use :class:`~airflow.providers.amazon.aws.operators.batch.BatchOperator`.
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_batch.py
diff --git a/tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py
b/tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py
new file mode 100644
index 0000000000..10be746ef9
--- /dev/null
+++ b/tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py
@@ -0,0 +1,213 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import sys
+
+import botocore
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.batch_client import
BatchClientAsyncHook
+
+if sys.version_info < (3, 8):
+ # For compatibility with Python 3.7
+ from asynctest import mock as async_mock
+else:
+ from unittest import mock as async_mock
+
+pytest.importorskip("aiobotocore")
+
+
+class TestBatchClientAsyncHook:
+ JOB_ID = "e2a459c5-381b-494d-b6e8-d6ee334db4e2"
+ BATCH_API_SUCCESS_RESPONSE = {"jobs": [{"jobId": JOB_ID, "status":
"SUCCEEDED"}]}
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status")
+ async def test_monitor_job_with_success(self, mock_poll_job_status,
mock_client):
+ """Tests that the monitor_job method returns expected event once
successful"""
+ mock_poll_job_status.return_value = True
+
mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = (
+ self.BATCH_API_SUCCESS_RESPONSE
+ )
+ hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+ result = await hook.monitor_job()
+ assert result == {"status": "success", "message": f"AWS Batch job
({self.JOB_ID}) succeeded"}
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status")
+ async def test_monitor_job_with_no_job_id(self, mock_poll_job_status,
mock_client):
+ """Tests that the monitor_job method raises expected exception when
incorrect job id is passed"""
+ mock_poll_job_status.return_value = True
+
mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = (
+ self.BATCH_API_SUCCESS_RESPONSE
+ )
+
+ with pytest.raises(AirflowException) as exc_info:
+ hook = BatchClientAsyncHook(job_id=False, waiters=None)
+ await hook.monitor_job()
+ assert str(exc_info.value) == "AWS Batch job - job_id was not found"
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status")
+ async def test_hit_api_throttle(self, mock_poll_job_status, mock_client):
+ """
+ Tests that the get_job_description method raises correct exception
when retries
+ exceed the threshold
+ """
+ mock_poll_job_status.return_value = True
+
mock_client.return_value.__aenter__.return_value.describe_jobs.side_effect = (
+ botocore.exceptions.ClientError(
+ error_response={
+ "Error": {
+ "Code": "TooManyRequestsException",
+ }
+ },
+ operation_name="get job description",
+ )
+ )
+ """status_retries = 2 ensures that exponential_delay block is covered
in batch_client.py
+ otherwise the code coverage will drop"""
+ hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None,
status_retries=2)
+ with pytest.raises(AirflowException) as exc_info:
+ await hook.get_job_description(job_id=self.JOB_ID)
+ assert (
+ str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) description
error: exceeded "
+ "status_retries (2)"
+ )
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status")
+ async def test_client_error(self, mock_poll_job_status, mock_client):
+ """Test that the get_job_description method raises correct exception
when the error code
+ from boto3 api is not TooManyRequestsException"""
+ mock_poll_job_status.return_value = True
+
mock_client.return_value.__aenter__.return_value.describe_jobs.side_effect = (
+ botocore.exceptions.ClientError(
+ error_response={"Error": {"Code": "InvalidClientTokenId",
"Message": "Malformed Token"}},
+ operation_name="get job description",
+ )
+ )
+ hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None,
status_retries=1)
+ with pytest.raises(AirflowException) as exc_info:
+ await hook.get_job_description(job_id=self.JOB_ID)
+ assert (
+ str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) description
error: An error "
+ "occurred (InvalidClientTokenId) when calling the get job
description operation: "
+ "Malformed Token"
+ )
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+ async def test_check_job_success(self, mock_client):
+ """Tests that the check_job_success method returns True when job
succeeds"""
+
mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = (
+ self.BATCH_API_SUCCESS_RESPONSE
+ )
+ hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+ result = await hook.check_job_success(job_id=self.JOB_ID)
+ assert result is True
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+ async def test_check_job_raises_exception_failed(self, mock_client):
+ """Tests that the check_job_success method raises exception correctly
as per job state"""
+ mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "FAILED"}]}
+
mock_client.return_value.__aenter__.return_value.describe_jobs.return_value =
mock_job
+ hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+ with pytest.raises(AirflowException) as exc_info:
+ await hook.check_job_success(job_id=self.JOB_ID)
+ assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) failed"
+ ": " + str(
+ mock_job["jobs"][0]
+ )
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+ async def test_check_job_raises_exception_pending(self, mock_client):
+ """Tests that the check_job_success method raises exception correctly
as per job state"""
+ mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "PENDING"}]}
+
mock_client.return_value.__aenter__.return_value.describe_jobs.return_value =
mock_job
+ hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+ with pytest.raises(AirflowException) as exc_info:
+ await hook.check_job_success(job_id=self.JOB_ID)
+ assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) is not
complete" + ": " + str(
+ mock_job["jobs"][0]
+ )
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+ async def test_check_job_raises_exception_strange(self, mock_client):
+ """Tests that the check_job_success method raises exception correctly
as per job state"""
+ mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "STRANGE"}]}
+
mock_client.return_value.__aenter__.return_value.describe_jobs.return_value =
mock_job
+ hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+ with pytest.raises(AirflowException) as exc_info:
+ await hook.check_job_success(job_id=self.JOB_ID)
+ assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) has
unknown status" + ": " + str(
+ mock_job["jobs"][0]
+ )
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+ async def test_check_job_raises_exception_runnable(self, mock_client):
+ """Tests that the check_job_success method raises exception correctly
as per job state"""
+ mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "RUNNABLE"}]}
+
mock_client.return_value.__aenter__.return_value.describe_jobs.return_value =
mock_job
+ hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+ with pytest.raises(AirflowException) as exc_info:
+ await hook.check_job_success(job_id=self.JOB_ID)
+ assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) is not
complete" + ": " + str(
+ mock_job["jobs"][0]
+ )
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+ async def test_check_job_raises_exception_submitted(self, mock_client):
+ """Tests that the check_job_success method raises exception correctly
as per job state"""
+ mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "SUBMITTED"}]}
+
mock_client.return_value.__aenter__.return_value.describe_jobs.return_value =
mock_job
+ hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+ with pytest.raises(AirflowException) as exc_info:
+ await hook.check_job_success(job_id=self.JOB_ID)
+ assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) is not
complete" + ": " + str(
+ mock_job["jobs"][0]
+ )
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+ async def test_poll_job_status_raises_for_max_retries(self, mock_client):
+ mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "RUNNABLE"}]}
+
mock_client.return_value.__aenter__.return_value.describe_jobs.return_value =
mock_job
+ hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None,
max_retries=1)
+ with pytest.raises(AirflowException) as exc_info:
+ await hook.poll_job_status(job_id=self.JOB_ID,
match_status=["SUCCEEDED"])
+ assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) status
checks exceed " "max_retries"
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+ async def test_poll_job_status_in_match_status(self, mock_client):
+ mock_job = self.BATCH_API_SUCCESS_RESPONSE
+
mock_client.return_value.__aenter__.return_value.describe_jobs.return_value =
mock_job
+ hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None,
max_retries=1)
+ result = await hook.poll_job_status(job_id=self.JOB_ID,
match_status=["SUCCEEDED"])
+ assert result is True
diff --git a/tests/providers/amazon/aws/deferrable/triggers/test_batch.py
b/tests/providers/amazon/aws/deferrable/triggers/test_batch.py
new file mode 100644
index 0000000000..ad534619f0
--- /dev/null
+++ b/tests/providers/amazon/aws/deferrable/triggers/test_batch.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import importlib.util
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.batch import (
+ BatchOperatorTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+from tests.providers.amazon.aws.utils.compat import async_mock
+
+JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3"
+JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
+MAX_RETRIES = 2
+STATUS_RETRIES = 3
+POKE_INTERVAL = 5
+AWS_CONN_ID = "airflow_test"
+REGION_NAME = "eu-west-1"
+
+
+@pytest.mark.skipif(not bool(importlib.util.find_spec("aiobotocore")), reason="aiobotocore require")
+class TestBatchOperatorTrigger:
+    """Tests for ``BatchOperatorTrigger``: its serialization and the events yielded by ``run()``.
+
+    The whole class is skipped when ``aiobotocore`` cannot be imported.
+    """
+
+    # Hook coroutine that the tests patch to control what the trigger's ``run()`` yields.
+    MONITOR_JOB = "airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job"
+
+    # Single trigger instance shared by every test in the class.
+    TRIGGER = BatchOperatorTrigger(
+        job_id=JOB_ID,
+        job_name=JOB_NAME,
+        job_definition="hello-world",
+        job_queue="queue",
+        waiters=None,
+        tags={},
+        max_retries=MAX_RETRIES,
+        status_retries=STATUS_RETRIES,
+        parameters={},
+        overrides={},
+        array_properties={},
+        region_name=REGION_NAME,
+        aws_conn_id=AWS_CONN_ID,
+    )
+
+    def test_batch_trigger_serialization(self):
+        """
+        Asserts that the BatchOperatorTrigger correctly serializes its arguments
+        and classpath.
+        """
+
+        classpath, kwargs = self.TRIGGER.serialize()
+        assert classpath == "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger"
+        assert kwargs == {
+            "job_id": JOB_ID,
+            "job_name": JOB_NAME,
+            "job_definition": "hello-world",
+            "job_queue": "queue",
+            "waiters": None,
+            "tags": {},
+            "max_retries": MAX_RETRIES,
+            "status_retries": STATUS_RETRIES,
+            "parameters": {},
+            "overrides": {},
+            "array_properties": {},
+            "region_name": REGION_NAME,
+            "aws_conn_id": AWS_CONN_ID,
+        }
+
+    @pytest.mark.asyncio
+    @async_mock.patch(MONITOR_JOB)
+    async def test_batch_trigger_run(self, mock_monitor_job):
+        """Test that the task is not done when event is not returned from trigger."""
+
+        async def never_finishes(*args, **kwargs):
+            # Stand-in for a job that is still being monitored: wait on an event nobody sets.
+            await asyncio.Event().wait()
+
+        # Without this patch the outcome would depend on what the real hook does within 0.5 s.
+        mock_monitor_job.side_effect = never_finishes
+
+        task = asyncio.create_task(self.TRIGGER.run().__anext__())
+        try:
+            await asyncio.sleep(0.5)
+            # TriggerEvent was not returned
+            assert task.done() is False
+        finally:
+            # Cancel and reap the task so it is not left pending when the event loop closes.
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
+
+    @pytest.mark.asyncio
+    @async_mock.patch(MONITOR_JOB)
+    async def test_batch_trigger_completed(self, mock_response):
+        """Test if the success event is returned from trigger."""
+        mock_response.return_value = {"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}
+
+        generator = self.TRIGGER.run()
+        actual_response = await generator.asend(None)
+        assert (
+            TriggerEvent({"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"})
+            == actual_response
+        )
+
+    @pytest.mark.asyncio
+    @async_mock.patch(MONITOR_JOB)
+    async def test_batch_trigger_failure(self, mock_response):
+        """Test if the failure event is returned from trigger."""
+        mock_response.return_value = {"status": "error", "message": f"{JOB_ID} failed"}
+
+        generator = self.TRIGGER.run()
+        actual_response = await generator.asend(None)
+        assert TriggerEvent({"status": "error", "message": f"{JOB_ID} failed"}) == actual_response
+
+    @pytest.mark.asyncio
+    @async_mock.patch(MONITOR_JOB)
+    async def test_batch_trigger_none(self, mock_response):
+        """Test if the failure event is returned when there is no response from hook."""
+        mock_response.return_value = None
+
+        generator = self.TRIGGER.run()
+        actual_response = await generator.asend(None)
+        assert TriggerEvent({"status": "error", "message": f"{JOB_ID} failed"}) == actual_response
+
+    @pytest.mark.asyncio
+    @async_mock.patch(MONITOR_JOB)
+    async def test_batch_trigger_exception(self, mock_response):
+        """Test that an exception raised by the hook is reported as a single error event."""
+        mock_response.side_effect = Exception("Test exception")
+
+        events = [event async for event in self.TRIGGER.run()]
+        assert events == [TriggerEvent({"status": "error", "message": "Test exception"})]
diff --git a/tests/providers/amazon/aws/hooks/test_batch_client.py
b/tests/providers/amazon/aws/hooks/test_batch_client.py
index 13726e5518..d7e06d9eb2 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_client.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import logging
from unittest import mock
+import botocore
import botocore.exceptions
import pytest
diff --git a/tests/providers/amazon/aws/operators/test_batch.py
b/tests/providers/amazon/aws/operators/test_batch.py
index 0ddfcea591..2192b7c4e2 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -19,11 +19,18 @@ from __future__ import annotations
from unittest import mock
+import pendulum
import pytest
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.models import DAG
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.operators.batch import
BatchCreateComputeEnvironmentOperator, BatchOperator
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.utils import timezone
+from airflow.utils.types import DagRunType
# Use dummy AWS credentials
AWS_REGION = "eu-west-1"
@@ -211,3 +218,114 @@ class TestBatchCreateComputeEnvironmentOperator:
computeResources=compute_resources,
tags=tags,
)
+
+
+def create_context(task, dag=None):
+    """Build a minimal execution-context dict for ``task``.
+
+    A ``DAG`` with id ``"dag"`` is created when ``dag`` is not given. The context is backed by
+    a manual ``DagRun`` dated 2022-01-01 01:00 UTC and a ``TaskInstance`` whose ``xcom_push``
+    is replaced with a ``Mock``.
+    """
+    dag = DAG(dag_id="dag") if dag is None else dag
+    utc = pendulum.timezone("UTC")
+    when = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=utc)
+    run = DagRun(
+        dag_id=dag.dag_id,
+        execution_date=when,
+        run_id=DagRun.generate_run_id(DagRunType.MANUAL, when),
+    )
+
+    ti = TaskInstance(task=task)
+    ti.dag_run = run
+    # Stub out XCom pushes on the task instance.
+    ti.xcom_push = mock.Mock()
+
+    context = {
+        "dag": dag,
+        "ts": when.isoformat(),
+        "task": task,
+        "ti": ti,
+        "task_instance": ti,
+        "run_id": run.run_id,
+        "dag_run": run,
+        "execution_date": when,
+        "data_interval_end": when,
+        "logical_date": when,
+    }
+    return context
+
+
+class TestBatchOperatorAsync:
+    """Tests for ``BatchOperator`` when it is created with ``deferrable=True``."""
+
+    JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3"
+    JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
+    MAX_RETRIES = 2
+    STATUS_RETRIES = 3
+    # Value the mocked client returns from ``submit_job``.
+    RESPONSE_WITHOUT_FAILURES = {
+        "jobName": JOB_NAME,
+        "jobId": JOB_ID,
+    }
+
+    def _make_operator(self):
+        """Build the deferrable ``BatchOperator`` that every test in this class exercises."""
+        return BatchOperator(
+            task_id="task",
+            job_name=self.JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            max_retries=self.MAX_RETRIES,
+            status_retries=self.STATUS_RETRIES,
+            parameters=None,
+            overrides={},
+            array_properties=None,
+            aws_conn_id="airflow_test",
+            region_name="eu-west-1",
+            tags={},
+            deferrable=True,
+        )
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
+    def test_batch_op_async(self, get_client_type_mock):
+        """Tests that executing the operator raises TaskDeferred carrying a BatchOperatorTrigger"""
+        get_client_type_mock.return_value.submit_job.return_value = self.RESPONSE_WITHOUT_FAILURES
+        task = self._make_operator()
+        context = create_context(task)
+        with pytest.raises(TaskDeferred) as exc:
+            task.execute(context)
+        assert isinstance(exc.value.trigger, BatchOperatorTrigger), "Trigger is not a BatchOperatorTrigger"
+
+    def test_batch_op_async_execute_failure(self):
+        """Tests that an AirflowException is raised in case of error event"""
+
+        task = self._make_operator()
+        with pytest.raises(AirflowException) as exc_info:
+            task.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
+
+        assert str(exc_info.value) == "test failure message"
+
+    @pytest.mark.parametrize(
+        "event",
+        [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}],
+    )
+    def test_batch_op_async_execute_complete(self, caplog, event):
+        """Tests that execute_complete method returns None and that it prints expected log"""
+        task = self._make_operator()
+        with mock.patch.object(task.log, "info") as mock_log_info:
+            assert task.execute_complete(context=None, event=event) is None
+
+        mock_log_info.assert_called_with(f"AWS Batch job ({self.JOB_ID}) succeeded")