This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 77c272e6e8 Add AWS deferrable BatchOperator (#29300)
77c272e6e8 is described below

commit 77c272e6e8ecda0ce48917064e58ba14f6a15844
Author: Rajath <92459020+rajaths010...@users.noreply.github.com>
AuthorDate: Wed Apr 5 20:27:13 2023 +0530

    Add AWS deferrable BatchOperator (#29300)
    
    This PR donates the deferrable version of the BatchOperator developed in the
    [astronomer-providers](https://github.com/astronomer/astronomer-providers) repo
    to Apache Airflow.
---
 airflow/providers/amazon/aws/hooks/batch_client.py | 239 ++++++++++++++++++++-
 airflow/providers/amazon/aws/operators/batch.py    |  37 ++++
 airflow/providers/amazon/aws/triggers/batch.py     | 123 +++++++++++
 .../operators/batch.rst                            |   5 +-
 .../aws/deferrable/hooks/test_batch_client.py      | 213 ++++++++++++++++++
 .../amazon/aws/deferrable/triggers/test_batch.py   | 131 +++++++++++
 .../amazon/aws/hooks/test_batch_client.py          |   1 +
 tests/providers/amazon/aws/operators/test_batch.py | 120 ++++++++++-
 8 files changed, 866 insertions(+), 3 deletions(-)
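
A minimal usage sketch of the new deferrable mode (the DAG id, job name, queue,
and job definition below are placeholders for illustration, not part of this
patch):

    import pendulum

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.batch import BatchOperator

    with DAG(
        dag_id="example_batch_deferrable",  # hypothetical DAG id
        start_date=pendulum.datetime(2023, 1, 1, tz="UTC"),
        schedule=None,
    ) as dag:
        submit_job = BatchOperator(
            task_id="submit_batch_job",
            job_name="my-batch-job",       # placeholder job name
            job_queue="my-job-queue",      # placeholder Batch queue
            job_definition="my-job-def",   # placeholder job definition
            overrides={},                  # containerOverrides passed to boto3
            deferrable=True,               # poll job status on the triggerer
        )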

diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py
index 526ab9a8a4..10b93afbbc 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -26,15 +26,17 @@ A client for AWS Batch services
 """
 from __future__ import annotations
 
+import asyncio
 from random import uniform
 from time import sleep
+from typing import Any
 
 import botocore.client
 import botocore.exceptions
 import botocore.waiter
 
 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook
 from airflow.typing_compat import Protocol, runtime_checkable
 
 
@@ -544,3 +546,238 @@ class BatchClientHook(AwsBaseHook):
         delay = 1 + pow(tries * 0.6, 2)
         delay = min(max_interval, delay)
         return uniform(delay / 3, delay)
+
+
+class BatchClientAsyncHook(BatchClientHook, AwsBaseAsyncHook):
+    """
+    Async client for AWS Batch services.
+
+    :param job_id: the job ID, usually unknown (None) until the
+        submit_job operation gets the jobId defined by AWS Batch
+
+    :param waiters: a :py:class:`.BatchWaiters` object (see note below);
+        if None, polling is used with max_retries and status_retries.
+
+    .. note::
+        Several methods use a default random delay to check or poll for job status, i.e.
+        ``random.uniform()``.
+        Using a random interval helps to avoid AWS API throttle limits
+        when many concurrent tasks request job-descriptions.
+
+        To modify the global defaults for the range of jitter allowed when a
+        random delay is used to check Batch job status, modify these defaults, e.g.:
+
+            BatchClientHook.DEFAULT_DELAY_MIN = 0
+            BatchClientHook.DEFAULT_DELAY_MAX = 5
+
+        When explicit delay values are used, a 1 second random jitter is applied to the
+        delay. It is generally recommended that random jitter is added to API requests.
+        A convenience method is provided for this, e.g. to get a random delay of
+        10 sec +/- 5 sec: ``delay = BatchClientHook.add_jitter(10, width=5, minima=0)``
+    """
+
+    def __init__(self, job_id: str | None, waiters: Any = None, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.job_id = job_id
+        self.waiters = waiters
+
+    async def monitor_job(self) -> dict[str, str] | None:
+        """
+        Monitor an AWS Batch job.
+
+        monitor_job can raise an exception, or an AirflowTaskTimeout can be raised if
+        execution_timeout is given while creating the task. These exceptions should be
+        handled in taskinstance.py instead of here, as was done previously.
+
+        :raises: AirflowException
+        """
+        if not self.job_id:
+            raise AirflowException("AWS Batch job - job_id was not found")
+
+        if self.waiters:
+            self.waiters.wait_for_job(self.job_id)
+            return None
+        else:
+            await self.wait_for_job(self.job_id)
+            await self.check_job_success(self.job_id)
+            success_msg = f"AWS Batch job ({self.job_id}) succeeded"
+            self.log.info(success_msg)
+            return {"status": "success", "message": success_msg}
+
+    async def check_job_success(self, job_id: str) -> bool:  # type: ignore[override]
+        """
+        Check the final status of the Batch job; return True if the job
+        'SUCCEEDED', else raise an AirflowException
+
+        :param job_id: a Batch job ID
+
+        :raises: AirflowException
+        """
+        job = await self.get_job_description(job_id)
+        job_status = job.get("status")
+        if job_status == self.SUCCESS_STATE:
+            self.log.info("AWS Batch job (%s) succeeded: %s", job_id, job)
+            return True
+
+        if job_status == self.FAILURE_STATE:
+            raise AirflowException(f"AWS Batch job ({job_id}) failed: {job}")
+
+        if job_status in self.INTERMEDIATE_STATES:
+            raise AirflowException(f"AWS Batch job ({job_id}) is not complete: 
{job}")
+
+        raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: 
{job}")
+
+    @staticmethod
+    async def delay(delay: int | float | None = None) -> None:  # type: ignore[override]
+        """
+        Pause execution for ``delay`` seconds.
+
+        :param delay: a delay to pause execution using ``asyncio.sleep(delay)``;
+            a small 1 second jitter is applied to the delay.
+
+        .. note::
+            This method uses a default random delay, i.e. ``random.uniform()``;
+            using a random interval helps to avoid AWS API throttle limits
+            when many concurrent tasks request job-descriptions.
+        """
+        if delay is None:
+            delay = uniform(BatchClientHook.DEFAULT_DELAY_MIN, BatchClientHook.DEFAULT_DELAY_MAX)
+        else:
+            delay = BatchClientAsyncHook.add_jitter(delay)
+        await asyncio.sleep(delay)
+
+    async def wait_for_job(  # type: ignore[override]
+        self, job_id: str, delay: int | float | None = None
+    ) -> None:
+        """
+        Wait for Batch job to complete.
+
+        :param job_id: a Batch job ID
+
+        :param delay: a delay before polling for job status
+
+        :raises: AirflowException
+        """
+        await self.delay(delay)
+        await self.poll_for_job_running(job_id, delay)
+        await self.poll_for_job_complete(job_id, delay)
+        self.log.info("AWS Batch job (%s) has completed", job_id)
+
+    async def poll_for_job_complete(  # type: ignore[override]
+        self, job_id: str, delay: int | float | None = None
+    ) -> None:
+        """
+        Poll for job completion. The statuses that indicate job completion
+        are: 'SUCCEEDED'|'FAILED'.
+
+        So the status options that this will wait for are the transitions from:
+        'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED'
+
+        :param job_id: a Batch job ID
+
+        :param delay: a delay before polling for job status
+
+        :raises: AirflowException
+        """
+        await self.delay(delay)
+        complete_status = [self.SUCCESS_STATE, self.FAILURE_STATE]
+        await self.poll_job_status(job_id, complete_status)
+
+    async def poll_for_job_running(  # type: ignore[override]
+        self, job_id: str, delay: int | float | None = None
+    ) -> None:
+        """
+        Poll for a running job. The statuses that indicate a job is running or
+        already complete are: 'RUNNING'|'SUCCEEDED'|'FAILED'.
+
+        So the status options that this will wait for are the transitions from:
+        'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'|'SUCCEEDED'|'FAILED'
+
+        The completed status options are included for cases where the status
+        changes too quickly for polling to detect a RUNNING status that moves
+        quickly from STARTING to RUNNING to completed (often a failure).
+
+        :param job_id: a Batch job ID
+
+        :param delay: a delay before polling for job status
+
+        :raises: AirflowException
+        """
+        await self.delay(delay)
+        running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE]
+        await self.poll_job_status(job_id, running_status)
+
+    async def get_job_description(self, job_id: str) -> dict[str, str]:  # type: ignore[override]
+        """
+        Get job description (using status_retries).
+
+        :param job_id: a Batch job ID
+        :raises: AirflowException
+        """
+        retries = 0
+        async with await self.get_client_async() as client:
+            while True:
+                try:
+                    response = await client.describe_jobs(jobs=[job_id])
+                    return self.parse_job_description(job_id, response)
+
+                except botocore.exceptions.ClientError as err:
+                    error = err.response.get("Error", {})
+                    if error.get("Code") == "TooManyRequestsException":
+                        pass  # allow it to retry, if possible
+                    else:
+                        raise AirflowException(f"AWS Batch job ({job_id}) 
description error: {err}")
+
+                retries += 1
+                if retries >= self.status_retries:
+                    raise AirflowException(
+                        f"AWS Batch job ({job_id}) description error: exceeded 
status_retries "
+                        f"({self.status_retries})"
+                    )
+
+                pause = self.exponential_delay(retries)
+                self.log.info(
+                    "AWS Batch job (%s) description retry (%d of %d) in the 
next %.2f seconds",
+                    job_id,
+                    retries,
+                    self.status_retries,
+                    pause,
+                )
+                await self.delay(pause)
+
+    async def poll_job_status(self, job_id: str, match_status: list[str]) -> bool:  # type: ignore[override]
+        """
+        Poll for job status using an exponential back-off strategy (with max_retries).
+
+        The Batch job statuses polled are:
+        'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
+
+        :param job_id: a Batch job ID
+        :param match_status: a list of job statuses to match
+        :raises: AirflowException
+        """
+        retries = 0
+        while True:
+            job = await self.get_job_description(job_id)
+            job_status = job.get("status")
+            self.log.info(
+                "AWS Batch job (%s) check status (%s) in %s",
+                job_id,
+                job_status,
+                match_status,
+            )
+            if job_status in match_status:
+                return True
+
+            if retries >= self.max_retries:
+                raise AirflowException(f"AWS Batch job ({job_id}) status 
checks exceed max_retries")
+
+            retries += 1
+            pause = self.exponential_delay(retries)
+            self.log.info(
+                "AWS Batch job (%s) status check (%d of %d) in the next %.2f 
seconds",
+                job_id,
+                retries,
+                self.max_retries,
+                pause,
+            )
+            await self.delay(pause)
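
For reference, the polling back-off used above grows quadratically with the
retry count before jitter is applied. A quick sketch re-deriving the schedule
from ``exponential_delay`` (the formula appears in the context lines of this
hunk); the 600-second cap on the delay is an assumption for illustration:

    from random import uniform

    def sketch_exponential_delay(tries: int, max_interval: float = 600) -> float:
        # Deterministic part: 1 + (0.6 * tries)^2, capped at max_interval.
        delay = 1 + pow(tries * 0.6, 2)
        delay = min(max_interval, delay)
        # Jitter: draw the actual pause uniformly from [delay / 3, delay].
        return uniform(delay / 3, delay)

    # Upper bounds for the first four retries: 1.36, 2.44, 4.24, 6.76 seconds.
    print([round(1 + pow(t * 0.6, 2), 2) for t in range(1, 5)])
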
diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 6565bcecfb..79a10a7b17 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -37,6 +37,7 @@ from airflow.providers.amazon.aws.links.batch import (
     BatchJobQueueLink,
 )
 from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
 from airflow.providers.amazon.aws.utils import trim_none_values
 
 if TYPE_CHECKING:
@@ -71,6 +72,7 @@ class BatchOperator(BaseOperator):
         Override the region_name in connection (if provided)
     :param tags: collection of tags to apply to the AWS Batch job submission
         if None, no tags are submitted
+    :param deferrable: Run the operator in deferrable mode.
 
     .. note::
         Any custom waiters must return a waiter for these calls:
@@ -125,6 +127,7 @@ class BatchOperator(BaseOperator):
         region_name: str | None = None,
         tags: dict | None = None,
         wait_for_completion: bool = True,
+        deferrable: bool = False,
         **kwargs,
     ):
 
@@ -139,6 +142,8 @@ class BatchOperator(BaseOperator):
         self.waiters = waiters
         self.tags = tags or {}
         self.wait_for_completion = wait_for_completion
+        self.deferrable = deferrable
+
         self.hook = BatchClientHook(
             max_retries=max_retries,
             status_retries=status_retries,
@@ -154,11 +159,43 @@ class BatchOperator(BaseOperator):
         """
         self.submit_job(context)
 
+        if self.deferrable:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=BatchOperatorTrigger(
+                    job_id=self.job_id,
+                    job_name=self.job_name,
+                    job_definition=self.job_definition,
+                    job_queue=self.job_queue,
+                    overrides=self.overrides,
+                    array_properties=self.array_properties,
+                    parameters=self.parameters,
+                    waiters=self.waiters,
+                    tags=self.tags,
+                    max_retries=self.hook.max_retries,
+                    status_retries=self.hook.status_retries,
+                    aws_conn_id=self.hook.aws_conn_id,
+                    region_name=self.hook.region_name,
+                ),
+                method_name="execute_complete",
+            )
+
         if self.wait_for_completion:
             self.monitor_job(context)
 
         return self.job_id
 
+    def execute_complete(self, context: Context, event: dict[str, Any]):
+        """
+        Callback for when the trigger fires; returns immediately.
+        Relies on the trigger to throw an exception, otherwise it assumes execution was
+        successful.
+        """
+        if "status" in event and event["status"] == "error":
+            raise AirflowException(event["message"])
+        self.log.info(event["message"])
+        return self.job_id
+
     def on_kill(self):
         response = self.hook.client.terminate_job(jobId=self.job_id, reason="Task killed by the user")
         self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, response)
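
For context, the worker and the trigger communicate through a plain event
dict; a minimal sketch of the contract that execute_complete above enforces
(the helper below is illustrative, not an Airflow API; the payload shapes
come from BatchOperatorTrigger.run() later in this patch):

    from airflow.exceptions import AirflowException

    def handle_trigger_event(event: dict) -> None:
        # The trigger yields {"status": "success" | "error", "message": ...};
        # errors are re-raised on the worker, successes are only logged.
        if event.get("status") == "error":
            raise AirflowException(event["message"])
        print(event["message"])

    handle_trigger_event({"status": "success", "message": "AWS Batch job (job-1) succeeded"})
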
diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py
new file mode 100644
index 0000000000..eb5a80a3c9
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/batch.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any, AsyncIterator
+
+from airflow.providers.amazon.aws.hooks.batch_client import BatchClientAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class BatchOperatorTrigger(BaseTrigger):
+    """
+    Checks the state of a job previously submitted to AWS Batch.
+    BatchOperatorTrigger is fired as a deferred class with params to poll the job
+    state on the triggerer.
+
+    :param job_id: the job ID, usually unknown (None) until the
+        submit_job operation gets the jobId defined by AWS Batch
+    :param job_name: the name for the job that will run on AWS Batch (templated)
+    :param job_definition: the job definition name on AWS Batch
+    :param job_queue: the queue name on AWS Batch
+    :param overrides: the `containerOverrides` parameter for boto3 (templated)
+    :param array_properties: the `arrayProperties` parameter for boto3
+    :param parameters: the `parameters` for boto3 (templated)
+    :param waiters: a :class:`.BatchWaiters` object (see note below);
+        if None, polling is used with max_retries and status_retries.
+    :param tags: collection of tags to apply to the AWS Batch job submission
+        if None, no tags are submitted
+    :param max_retries: exponential back-off retries, 4200 = 48 hours;
+        polling is only used when waiters is None
+    :param status_retries: number of HTTP retries to get job status, 10;
+        polling is only used when waiters is None
+    :param aws_conn_id: connection ID for AWS credentials / region name. If None,
+        the default boto3 credential strategy will be used.
+    :param region_name: AWS region name to use.
+        Overrides the region_name in the connection (if provided)
+    """
+
+    def __init__(
+        self,
+        job_id: str | None,
+        job_name: str,
+        job_definition: str,
+        job_queue: str,
+        overrides: dict[str, str],
+        array_properties: dict[str, str],
+        parameters: dict[str, str],
+        waiters: Any,
+        tags: dict[str, str],
+        max_retries: int,
+        status_retries: int,
+        region_name: str | None,
+        aws_conn_id: str | None = "aws_default",
+    ):
+        super().__init__()
+        self.job_id = job_id
+        self.job_name = job_name
+        self.job_definition = job_definition
+        self.job_queue = job_queue
+        self.overrides = overrides or {}
+        self.array_properties = array_properties or {}
+        self.parameters = parameters or {}
+        self.waiters = waiters
+        self.tags = tags or {}
+        self.max_retries = max_retries
+        self.status_retries = status_retries
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes BatchOperatorTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger",
+            {
+                "job_id": self.job_id,
+                "job_name": self.job_name,
+                "job_definition": self.job_definition,
+                "job_queue": self.job_queue,
+                "overrides": self.overrides,
+                "array_properties": self.array_properties,
+                "parameters": self.parameters,
+                "waiters": self.waiters,
+                "tags": self.tags,
+                "max_retries": self.max_retries,
+                "status_retries": self.status_retries,
+                "aws_conn_id": self.aws_conn_id,
+                "region_name": self.region_name,
+            },
+        )
+
+    async def run(self) -> AsyncIterator["TriggerEvent"]:
+        """
+        Make an async connection to AWS Batch using the aiobotocore library and
+        periodically poll for the job status on the triggerer.
+
+        The statuses that indicate job completion are: 'SUCCEEDED'|'FAILED'.
+
+        So the status options that this will poll for are the transitions from:
+        'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED'
+        """
+        hook = BatchClientAsyncHook(job_id=self.job_id, waiters=self.waiters, aws_conn_id=self.aws_conn_id)
+        try:
+            response = await hook.monitor_job()
+            if response:
+                yield TriggerEvent(response)
+            else:
+                error_message = f"{self.job_id} failed"
+                yield TriggerEvent({"status": "error", "message": 
error_message})
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e)})
diff --git a/docs/apache-airflow-providers-amazon/operators/batch.rst b/docs/apache-airflow-providers-amazon/operators/batch.rst
index ba280cb38d..0c686184b9 100644
--- a/docs/apache-airflow-providers-amazon/operators/batch.rst
+++ b/docs/apache-airflow-providers-amazon/operators/batch.rst
@@ -37,7 +37,10 @@ Operators
 Submit a new AWS Batch job
 ==========================
 
-To submit a new AWS Batch job and monitor it until it reaches a terminal state you can
+To submit a new AWS Batch job and monitor it until it reaches a terminal state,
 use :class:`~airflow.providers.amazon.aws.operators.batch.BatchOperator`.
+You can also run this operator in deferrable mode by setting the parameter ``deferrable``
+to True. This leads to more efficient utilization of Airflow workers, as polling for the
+job status happens asynchronously on the triggerer. Note that deferrable mode requires a
+triggerer to be available in your Airflow deployment.
 
 .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_batch.py
diff --git a/tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py b/tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py
new file mode 100644
index 0000000000..10be746ef9
--- /dev/null
+++ b/tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py
@@ -0,0 +1,213 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import sys
+
+import botocore
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.batch_client import BatchClientAsyncHook
+
+if sys.version_info < (3, 8):
+    # For compatibility with Python 3.7
+    from asynctest import mock as async_mock
+else:
+    from unittest import mock as async_mock
+
+pytest.importorskip("aiobotocore")
+
+
+class TestBatchClientAsyncHook:
+    JOB_ID = "e2a459c5-381b-494d-b6e8-d6ee334db4e2"
+    BATCH_API_SUCCESS_RESPONSE = {"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]}
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status")
+    async def test_monitor_job_with_success(self, mock_poll_job_status, mock_client):
+        """Tests that the monitor_job method returns the expected event once successful"""
+        mock_poll_job_status.return_value = True
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = (
+            self.BATCH_API_SUCCESS_RESPONSE
+        )
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+        result = await hook.monitor_job()
+        assert result == {"status": "success", "message": f"AWS Batch job ({self.JOB_ID}) succeeded"}
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status")
+    async def test_monitor_job_with_no_job_id(self, mock_poll_job_status, mock_client):
+        """Tests that the monitor_job method raises the expected exception when an incorrect job id is passed"""
+        mock_poll_job_status.return_value = True
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = (
+            self.BATCH_API_SUCCESS_RESPONSE
+        )
+
+        with pytest.raises(AirflowException) as exc_info:
+            hook = BatchClientAsyncHook(job_id=False, waiters=None)
+            await hook.monitor_job()
+        assert str(exc_info.value) == "AWS Batch job - job_id was not found"
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status")
+    async def test_hit_api_throttle(self, mock_poll_job_status, mock_client):
+        """
+        Tests that the get_job_description method raises the correct exception when retries
+        exceed the threshold
+        """
+        mock_poll_job_status.return_value = True
+        mock_client.return_value.__aenter__.return_value.describe_jobs.side_effect = (
+            botocore.exceptions.ClientError(
+                error_response={
+                    "Error": {
+                        "Code": "TooManyRequestsException",
+                    }
+                },
+                operation_name="get job description",
+            )
+        )
+        # status_retries=2 ensures that the exponential_delay block in batch_client.py
+        # is covered; otherwise code coverage would drop.
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None, status_retries=2)
+        with pytest.raises(AirflowException) as exc_info:
+            await hook.get_job_description(job_id=self.JOB_ID)
+        assert (
+            str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) description error: exceeded "
+            "status_retries (2)"
+        )
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status")
+    async def test_client_error(self, mock_poll_job_status, mock_client):
+        """Tests that the get_job_description method raises the correct exception when the
+        error code from the boto3 API is not TooManyRequestsException"""
+        mock_poll_job_status.return_value = True
+        mock_client.return_value.__aenter__.return_value.describe_jobs.side_effect = (
+            botocore.exceptions.ClientError(
+                error_response={"Error": {"Code": "InvalidClientTokenId", "Message": "Malformed Token"}},
+                operation_name="get job description",
+            )
+        )
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None, status_retries=1)
+        with pytest.raises(AirflowException) as exc_info:
+            await hook.get_job_description(job_id=self.JOB_ID)
+        assert (
+            str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) description error: An error "
+            "occurred (InvalidClientTokenId) when calling the get job description operation: "
+            "Malformed Token"
+        )
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    async def test_check_job_success(self, mock_client):
+        """Tests that the check_job_success method returns True when the job succeeds"""
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = (
+            self.BATCH_API_SUCCESS_RESPONSE
+        )
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+        result = await hook.check_job_success(job_id=self.JOB_ID)
+        assert result is True
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    async def test_check_job_raises_exception_failed(self, mock_client):
+        """Tests that the check_job_success method raises the correct exception for a FAILED job"""
+        mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "FAILED"}]}
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+        with pytest.raises(AirflowException) as exc_info:
+            await hook.check_job_success(job_id=self.JOB_ID)
+        assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) failed" + ": " + str(
+            mock_job["jobs"][0]
+        )
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    async def test_check_job_raises_exception_pending(self, mock_client):
+        """Tests that the check_job_success method raises the correct exception for a PENDING job"""
+        mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "PENDING"}]}
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+        with pytest.raises(AirflowException) as exc_info:
+            await hook.check_job_success(job_id=self.JOB_ID)
+        assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) is not complete" + ": " + str(
+            mock_job["jobs"][0]
+        )
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    async def test_check_job_raises_exception_strange(self, mock_client):
+        """Tests that the check_job_success method raises the correct exception for an unknown job state"""
+        mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "STRANGE"}]}
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+        with pytest.raises(AirflowException) as exc_info:
+            await hook.check_job_success(job_id=self.JOB_ID)
+        assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) has unknown status" + ": " + str(
+            mock_job["jobs"][0]
+        )
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    async def test_check_job_raises_exception_runnable(self, mock_client):
+        """Tests that the check_job_success method raises the correct exception for a RUNNABLE job"""
+        mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "RUNNABLE"}]}
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+        with pytest.raises(AirflowException) as exc_info:
+            await hook.check_job_success(job_id=self.JOB_ID)
+        assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) is not complete" + ": " + str(
+            mock_job["jobs"][0]
+        )
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    async def test_check_job_raises_exception_submitted(self, mock_client):
+        """Tests that the check_job_success method raises the correct exception for a SUBMITTED job"""
+        mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "SUBMITTED"}]}
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None)
+        with pytest.raises(AirflowException) as exc_info:
+            await hook.check_job_success(job_id=self.JOB_ID)
+        assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) is not complete" + ": " + str(
+            mock_job["jobs"][0]
+        )
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    async def test_poll_job_status_raises_for_max_retries(self, mock_client):
+        """Tests that poll_job_status raises once the configured max_retries is exceeded"""
+        mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "RUNNABLE"}]}
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None, max_retries=1)
+        with pytest.raises(AirflowException) as exc_info:
+            await hook.poll_job_status(job_id=self.JOB_ID, match_status=["SUCCEEDED"])
+        assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) status checks exceed max_retries"
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async")
+    async def test_poll_job_status_in_match_status(self, mock_client):
+        """Tests that poll_job_status returns True as soon as the job status matches"""
+        mock_job = self.BATCH_API_SUCCESS_RESPONSE
+        mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job
+        hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None, max_retries=1)
+        result = await hook.poll_job_status(job_id=self.JOB_ID, match_status=["SUCCEEDED"])
+        assert result is True
diff --git a/tests/providers/amazon/aws/deferrable/triggers/test_batch.py b/tests/providers/amazon/aws/deferrable/triggers/test_batch.py
new file mode 100644
index 0000000000..ad534619f0
--- /dev/null
+++ b/tests/providers/amazon/aws/deferrable/triggers/test_batch.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import importlib.util
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.batch import (
+    BatchOperatorTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+from tests.providers.amazon.aws.utils.compat import async_mock
+
+JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3"
+JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
+MAX_RETRIES = 2
+STATUS_RETRIES = 3
+POKE_INTERVAL = 5
+AWS_CONN_ID = "airflow_test"
+REGION_NAME = "eu-west-1"
+
+
+@pytest.mark.skipif(not bool(importlib.util.find_spec("aiobotocore")), reason="aiobotocore required")
+class TestBatchOperatorTrigger:
+    TRIGGER = BatchOperatorTrigger(
+        job_id=JOB_ID,
+        job_name=JOB_NAME,
+        job_definition="hello-world",
+        job_queue="queue",
+        waiters=None,
+        tags={},
+        max_retries=MAX_RETRIES,
+        status_retries=STATUS_RETRIES,
+        parameters={},
+        overrides={},
+        array_properties={},
+        region_name="eu-west-1",
+        aws_conn_id="airflow_test",
+    )
+
+    def test_batch_trigger_serialization(self):
+        """
+        Asserts that the BatchOperatorTrigger correctly serializes its arguments
+        and classpath.
+        """
+
+        classpath, kwargs = self.TRIGGER.serialize()
+        assert classpath == "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger"
+        assert kwargs == {
+            "job_id": JOB_ID,
+            "job_name": JOB_NAME,
+            "job_definition": "hello-world",
+            "job_queue": "queue",
+            "waiters": None,
+            "tags": {},
+            "max_retries": MAX_RETRIES,
+            "status_retries": STATUS_RETRIES,
+            "parameters": {},
+            "overrides": {},
+            "array_properties": {},
+            "region_name": "eu-west-1",
+            "aws_conn_id": "airflow_test",
+        }
+
+    @pytest.mark.asyncio
+    async def test_batch_trigger_run(self):
+        """Test that the task is not done when event is not returned from 
trigger."""
+
+        task = asyncio.create_task(self.TRIGGER.run().__anext__())
+        await asyncio.sleep(0.5)
+        # TriggerEvent was not returned
+        assert task.done() is False
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job")
+    async def test_batch_trigger_completed(self, mock_response):
+        """Test that the success event is returned from the trigger."""
+        mock_response.return_value = {"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}
+
+        generator = self.TRIGGER.run()
+        actual_response = await generator.asend(None)
+        assert (
+            TriggerEvent({"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"})
+            == actual_response
+        )
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job")
+    async def test_batch_trigger_failure(self, mock_response):
+        """Test that the failure event is returned from the trigger."""
+        mock_response.return_value = {"status": "error", "message": f"{JOB_ID} failed"}
+
+        generator = self.TRIGGER.run()
+        actual_response = await generator.asend(None)
+        assert TriggerEvent({"status": "error", "message": f"{JOB_ID} failed"}) == actual_response
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job")
+    async def test_batch_trigger_none(self, mock_response):
+        """Test that the failure event is returned when there is no response from the hook."""
+        mock_response.return_value = None
+
+        generator = self.TRIGGER.run()
+        actual_response = await generator.asend(None)
+        assert TriggerEvent({"status": "error", "message": f"{JOB_ID} failed"}) == actual_response
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job")
+    async def test_batch_trigger_exception(self, mock_response):
+        """Test that an error event is yielded when the hook raises an exception."""
+        mock_response.side_effect = Exception("Test exception")
+
+        task = [i async for i in self.TRIGGER.run()]
+        assert len(task) == 1
+        assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
diff --git a/tests/providers/amazon/aws/hooks/test_batch_client.py b/tests/providers/amazon/aws/hooks/test_batch_client.py
index 13726e5518..d7e06d9eb2 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_client.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 import logging
 from unittest import mock
 
+import botocore
 import botocore.exceptions
 import pytest
 
diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py
index 0ddfcea591..2192b7c4e2 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -19,11 +19,18 @@ from __future__ import annotations
 
 from unittest import mock
 
+import pendulum
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.models import DAG
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
 from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator, BatchOperator
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.utils import timezone
+from airflow.utils.types import DagRunType
 
 # Use dummy AWS credentials
 AWS_REGION = "eu-west-1"
@@ -211,3 +218,114 @@ class TestBatchCreateComputeEnvironmentOperator:
             computeResources=compute_resources,
             tags=tags,
         )
+
+
+def create_context(task, dag=None):
+    if dag is None:
+        dag = DAG(dag_id="dag")
+    tzinfo = pendulum.timezone("UTC")
+    execution_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo)
+    dag_run = DagRun(
+        dag_id=dag.dag_id,
+        execution_date=execution_date,
+        run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
+    )
+
+    task_instance = TaskInstance(task=task)
+    task_instance.dag_run = dag_run
+    task_instance.xcom_push = mock.Mock()
+    return {
+        "dag": dag,
+        "ts": execution_date.isoformat(),
+        "task": task,
+        "ti": task_instance,
+        "task_instance": task_instance,
+        "run_id": dag_run.run_id,
+        "dag_run": dag_run,
+        "execution_date": execution_date,
+        "data_interval_end": execution_date,
+        "logical_date": execution_date,
+    }
+
+
+class TestBatchOperatorAsync:
+    JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3"
+    JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
+    MAX_RETRIES = 2
+    STATUS_RETRIES = 3
+    RESPONSE_WITHOUT_FAILURES = {
+        "jobName": JOB_NAME,
+        "jobId": JOB_ID,
+    }
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
+    def test_batch_op_async(self, get_client_type_mock):
+        get_client_type_mock.return_value.submit_job.return_value = self.RESPONSE_WITHOUT_FAILURES
+        task = BatchOperator(
+            task_id="task",
+            job_name=self.JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            max_retries=self.MAX_RETRIES,
+            status_retries=self.STATUS_RETRIES,
+            parameters=None,
+            overrides={},
+            array_properties=None,
+            aws_conn_id="airflow_test",
+            region_name="eu-west-1",
+            tags={},
+            deferrable=True,
+        )
+        context = create_context(task)
+        with pytest.raises(TaskDeferred) as exc:
+            task.execute(context)
+        assert isinstance(exc.value.trigger, BatchOperatorTrigger), "Trigger is not a BatchOperatorTrigger"
+
+    def test_batch_op_async_execute_failure(self):
+        """Tests that an AirflowException is raised in case of error event"""
+
+        task = BatchOperator(
+            task_id="task",
+            job_name=self.JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            max_retries=self.MAX_RETRIES,
+            status_retries=self.STATUS_RETRIES,
+            parameters=None,
+            overrides={},
+            array_properties=None,
+            aws_conn_id="airflow_test",
+            region_name="eu-west-1",
+            tags={},
+            deferrable=True,
+        )
+        with pytest.raises(AirflowException) as exc_info:
+            task.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
+
+        assert str(exc_info.value) == "test failure message"
+
+    @pytest.mark.parametrize(
+        "event",
+        [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) 
succeeded"}],
+    )
+    def test_batch_op_async_execute_complete(self, caplog, event):
+        """Tests that execute_complete method returns None and that it prints 
expected log"""
+        task = BatchOperator(
+            task_id="task",
+            job_name=self.JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            max_retries=self.MAX_RETRIES,
+            status_retries=self.STATUS_RETRIES,
+            parameters=None,
+            overrides={},
+            array_properties=None,
+            aws_conn_id="airflow_test",
+            region_name="eu-west-1",
+            tags={},
+            deferrable=True,
+        )
+        with mock.patch.object(task.log, "info") as mock_log_info:
+            assert task.execute_complete(context=None, event=event) is None
+
+        mock_log_info.assert_called_with(f"AWS Batch job ({self.JOB_ID}) succeeded")

