This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5b0f668b6d Add deferrable param in BatchOperator (#30865)
5b0f668b6d is described below

commit 5b0f668b6dad448bdf99976658231f6ffa607f8b
Author: Pankaj Singh <[email protected]>
AuthorDate: Fri May 26 04:54:02 2023 +0530

    Add deferrable param in BatchOperator (#30865)
    
    Add the deferrable param in BatchOperator.
    This allows running BatchOperator asynchronously: the worker
    only submits the batch job, then defers to the trigger,
    which polls and waits for the job status,
    so the worker slot is not occupied for the whole period of
    task execution.
---
 airflow/providers/amazon/aws/operators/batch.py    |  27 ++++++
 airflow/providers/amazon/aws/triggers/batch.py     | 107 +++++++++++++++++++++
 airflow/providers/amazon/aws/waiters/batch.json    |  25 +++++
 tests/providers/amazon/aws/operators/test_batch.py |  19 +++-
 tests/providers/amazon/aws/triggers/test_batch.py  |  69 +++++++++++++
 5 files changed, 246 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/amazon/aws/operators/batch.py 
b/airflow/providers/amazon/aws/operators/batch.py
index 272122d109..cbeb0cbcba 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -38,6 +38,7 @@ from airflow.providers.amazon.aws.links.batch import (
     BatchJobQueueLink,
 )
 from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
 from airflow.providers.amazon.aws.utils import trim_none_values
 
 if TYPE_CHECKING:
@@ -79,6 +80,8 @@ class BatchOperator(BaseOperator):
         Override the region_name in connection (if provided)
     :param tags: collection of tags to apply to the AWS Batch job submission
         if None, no tags are submitted
+    :param deferrable: Run operator in the deferrable mode.
+    :param poll_interval: (Deferrable mode only) Time in seconds to wait 
between polling.
 
     .. note::
         Any custom waiters must return a waiter for these calls:
@@ -142,6 +145,8 @@ class BatchOperator(BaseOperator):
         region_name: str | None = None,
         tags: dict | None = None,
         wait_for_completion: bool = True,
+        deferrable: bool = False,
+        poll_interval: int = 30,
         **kwargs,
     ):
 
@@ -175,6 +180,8 @@ class BatchOperator(BaseOperator):
         self.waiters = waiters
         self.tags = tags or {}
         self.wait_for_completion = wait_for_completion
+        self.deferrable = deferrable
+        self.poll_interval = poll_interval
 
         # params for hook
         self.max_retries = max_retries
@@ -199,11 +206,31 @@ class BatchOperator(BaseOperator):
         """
         self.submit_job(context)
 
+        if self.deferrable:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=BatchOperatorTrigger(
+                    job_id=self.job_id,
+                    max_retries=self.max_retries or 10,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    poll_interval=self.poll_interval,
+                ),
+                method_name="execute_complete",
+            )
+
         if self.wait_for_completion:
             self.monitor_job(context)
 
         return self.job_id
 
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        else:
+            self.log.info("Job completed.")
+        return event["job_id"]
+
     def on_kill(self):
         response = self.hook.client.terminate_job(jobId=self.job_id, 
reason="Task killed by the user")
         self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, 
response)
diff --git a/airflow/providers/amazon/aws/triggers/batch.py 
b/airflow/providers/amazon/aws/triggers/batch.py
new file mode 100644
index 0000000000..fb60b7ea91
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/batch.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from typing import Any
+
+from botocore.exceptions import WaiterError
+
+from airflow.compat.functools import cached_property
+from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class BatchOperatorTrigger(BaseTrigger):
+    """
+    Trigger for BatchOperator.
+    The trigger will asynchronously poll the boto3 API and wait for the
+    Batch job to be in the `SUCCEEDED` state.
+
+    :param job_id:  A unique identifier for the cluster.
+    :param max_retries: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: region name to use in AWS Hook
+    :param poll_interval: The amount of time in seconds to wait between 
attempts.
+    """
+
+    def __init__(
+        self,
+        job_id: str | None = None,
+        max_retries: int = 10,
+        aws_conn_id: str | None = "aws_default",
+        region_name: str | None = None,
+        poll_interval: int = 30,
+    ):
+        super().__init__()
+        self.job_id = job_id
+        self.max_retries = max_retries
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes BatchOperatorTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger",
+            {
+                "job_id": self.job_id,
+                "max_retries": self.max_retries,
+                "aws_conn_id": self.aws_conn_id,
+                "region_name": self.region_name,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> BatchClientHook:
+        return BatchClientHook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
+
+    async def run(self):
+
+        async with self.hook.async_conn as client:
+            waiter = self.hook.get_waiter("batch_job_complete", 
deferrable=True, client=client)
+            attempt = 0
+            while attempt < self.max_retries:
+                attempt = attempt + 1
+                try:
+                    await waiter.wait(
+                        jobs=[self.job_id],
+                        WaiterConfig={
+                            "Delay": self.poll_interval,
+                            "MaxAttempts": 1,
+                        },
+                    )
+                    break
+                except WaiterError as error:
+                    if "terminal failure" in str(error):
+                        yield TriggerEvent(
+                            {"status": "failure", "message": f"Delete Cluster 
Failed: {error}"}
+                        )
+                        break
+                    self.log.info(
+                        "Job status is %s. Retrying attempt %s/%s",
+                        error.last_response["jobs"][0]["status"],
+                        attempt,
+                        self.max_retries,
+                    )
+                    await asyncio.sleep(int(self.poll_interval))
+
+        if attempt >= self.max_retries:
+            yield TriggerEvent({"status": "failure", "message": "Job Failed - 
max attempts reached."})
+        else:
+            yield TriggerEvent({"status": "success", "job_id": self.job_id})
diff --git a/airflow/providers/amazon/aws/waiters/batch.json 
b/airflow/providers/amazon/aws/waiters/batch.json
new file mode 100644
index 0000000000..fa9752ea14
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/batch.json
@@ -0,0 +1,25 @@
+{
+  "version": 2,
+  "waiters": {
+    "batch_job_complete": {
+      "delay": 300,
+      "operation": "DescribeJobs",
+      "maxAttempts": 100,
+      "description": "Wait until job is SUCCEEDED",
+      "acceptors": [
+        {
+          "argument": "jobs[].status",
+          "expected": "SUCCEEDED",
+          "matcher": "pathAll",
+          "state": "success"
+        },
+        {
+          "argument": "jobs[].status",
+          "expected": "FAILED",
+          "matcher": "pathAll",
+          "state": "failed"
+        }
+      ]
+    }
+  }
+}
diff --git a/tests/providers/amazon/aws/operators/test_batch.py 
b/tests/providers/amazon/aws/operators/test_batch.py
index 42f7fae86c..f559424dff 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -22,11 +22,13 @@ from unittest.mock import patch
 
 import pytest
 
-from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning, TaskDeferred
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
 from airflow.providers.amazon.aws.operators.batch import 
BatchCreateComputeEnvironmentOperator, BatchOperator
 
 # Use dummy AWS credentials
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+
 AWS_REGION = "eu-west-1"
 AWS_ACCESS_KEY_ID = "airflow_dummy_key"
 AWS_SECRET_ACCESS_KEY = "airflow_dummy_secret"
@@ -256,6 +258,21 @@ class TestBatchOperator:
                 container_overrides={"a": "b"},
             )
 
+    
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
+    def test_defer_if_deferrable_param_set(self, mock_client):
+        batch = BatchOperator(
+            task_id="task",
+            job_name=JOB_NAME,
+            job_queue="queue",
+            job_definition="hello-world",
+            do_xcom_push=False,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as exc:
+            batch.execute(context=None)
+        assert isinstance(exc.value.trigger, BatchOperatorTrigger), "Trigger 
is not a BatchOperatorTrigger"
+
 
 class TestBatchCreateComputeEnvironmentOperator:
     @mock.patch.object(BatchClientHook, "client")
diff --git a/tests/providers/amazon/aws/triggers/test_batch.py 
b/tests/providers/amazon/aws/triggers/test_batch.py
new file mode 100644
index 0000000000..54a8765ab4
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_batch.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.triggers.base import TriggerEvent
+from tests.providers.amazon.aws.utils.compat import AsyncMock, async_mock
+
+BATCH_JOB_ID = "job_id"
+POLL_INTERVAL = 5
+MAX_ATTEMPT = 5
+AWS_CONN_ID = "aws_batch_job_conn"
+AWS_REGION = "us-east-2"
+
+
+class TestBatchOperatorTrigger:
+    def test_batch_operator_trigger_serialize(self):
+        batch_trigger = BatchOperatorTrigger(
+            job_id=BATCH_JOB_ID,
+            poll_interval=POLL_INTERVAL,
+            max_retries=MAX_ATTEMPT,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=AWS_REGION,
+        )
+        class_path, args = batch_trigger.serialize()
+        assert class_path == 
"airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger"
+        assert args["job_id"] == BATCH_JOB_ID
+        assert args["poll_interval"] == POLL_INTERVAL
+        assert args["max_retries"] == MAX_ATTEMPT
+        assert args["aws_conn_id"] == AWS_CONN_ID
+        assert args["region_name"] == AWS_REGION
+
+    @pytest.mark.asyncio
+    
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter")
+    
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn")
+    async def test_batch_job_trigger_run(self, mock_async_conn, 
mock_get_waiter):
+        mock = async_mock.MagicMock()
+        mock_async_conn.__aenter__.return_value = mock
+
+        mock_get_waiter().wait = AsyncMock()
+
+        batch_trigger = BatchOperatorTrigger(
+            job_id=BATCH_JOB_ID,
+            poll_interval=POLL_INTERVAL,
+            max_retries=MAX_ATTEMPT,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=AWS_REGION,
+        )
+
+        generator = batch_trigger.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "success", "job_id": 
BATCH_JOB_ID})

Reply via email to