This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5b0f668b6d Add deferrable param in BatchOperator (#30865)
5b0f668b6d is described below
commit 5b0f668b6dad448bdf99976658231f6ffa607f8b
Author: Pankaj Singh <[email protected]>
AuthorDate: Fri May 26 04:54:02 2023 +0530
Add deferrable param in BatchOperator (#30865)
Add the deferrable param in BatchOperator.
This will allow running BatchOperator in an async way:
we only submit the batch job from the worker, then defer
to the trigger to poll and wait for the job status, so the
worker slot won't be occupied for the whole period of
task execution.
---
airflow/providers/amazon/aws/operators/batch.py | 27 ++++++
airflow/providers/amazon/aws/triggers/batch.py | 107 +++++++++++++++++++++
airflow/providers/amazon/aws/waiters/batch.json | 25 +++++
tests/providers/amazon/aws/operators/test_batch.py | 19 +++-
tests/providers/amazon/aws/triggers/test_batch.py | 69 +++++++++++++
5 files changed, 246 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/amazon/aws/operators/batch.py
b/airflow/providers/amazon/aws/operators/batch.py
index 272122d109..cbeb0cbcba 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -38,6 +38,7 @@ from airflow.providers.amazon.aws.links.batch import (
BatchJobQueueLink,
)
from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
from airflow.providers.amazon.aws.utils import trim_none_values
if TYPE_CHECKING:
@@ -79,6 +80,8 @@ class BatchOperator(BaseOperator):
Override the region_name in connection (if provided)
:param tags: collection of tags to apply to the AWS Batch job submission
if None, no tags are submitted
+ :param deferrable: Run operator in the deferrable mode.
+ :param poll_interval: (Deferrable mode only) Time in seconds to wait
between polling attempts.
.. note::
Any custom waiters must return a waiter for these calls:
@@ -142,6 +145,8 @@ class BatchOperator(BaseOperator):
region_name: str | None = None,
tags: dict | None = None,
wait_for_completion: bool = True,
+ deferrable: bool = False,
+ poll_interval: int = 30,
**kwargs,
):
@@ -175,6 +180,8 @@ class BatchOperator(BaseOperator):
self.waiters = waiters
self.tags = tags or {}
self.wait_for_completion = wait_for_completion
+ self.deferrable = deferrable
+ self.poll_interval = poll_interval
# params for hook
self.max_retries = max_retries
@@ -199,11 +206,31 @@ class BatchOperator(BaseOperator):
"""
self.submit_job(context)
+ if self.deferrable:
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=BatchOperatorTrigger(
+ job_id=self.job_id,
+ max_retries=self.max_retries or 10,
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ poll_interval=self.poll_interval,
+ ),
+ method_name="execute_complete",
+ )
+
if self.wait_for_completion:
self.monitor_job(context)
return self.job_id
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error while running job: {event}")
+ else:
+ self.log.info("Job completed.")
+ return event["job_id"]
+
def on_kill(self):
response = self.hook.client.terminate_job(jobId=self.job_id,
reason="Task killed by the user")
self.log.info("AWS Batch job (%s) terminated: %s", self.job_id,
response)
diff --git a/airflow/providers/amazon/aws/triggers/batch.py
b/airflow/providers/amazon/aws/triggers/batch.py
new file mode 100644
index 0000000000..fb60b7ea91
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/batch.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from typing import Any
+
+from botocore.exceptions import WaiterError
+
+from airflow.compat.functools import cached_property
+from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class BatchOperatorTrigger(BaseTrigger):
+ """
+ Trigger for BatchOperator.
+ The trigger will asynchronously poll the boto3 API and wait for the
+ Batch job to be in the `SUCCEEDED` state.
+
+ :param job_id: The ID of the submitted AWS Batch job.
+ :param max_retries: The maximum number of attempts to be made.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param region_name: region name to use in AWS Hook
+ :param poll_interval: The amount of time in seconds to wait between
attempts.
+ """
+
+ def __init__(
+ self,
+ job_id: str | None = None,
+ max_retries: int = 10,
+ aws_conn_id: str | None = "aws_default",
+ region_name: str | None = None,
+ poll_interval: int = 30,
+ ):
+ super().__init__()
+ self.job_id = job_id
+ self.max_retries = max_retries
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+ self.poll_interval = poll_interval
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes BatchOperatorTrigger arguments and classpath."""
+ return (
+ "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger",
+ {
+ "job_id": self.job_id,
+ "max_retries": self.max_retries,
+ "aws_conn_id": self.aws_conn_id,
+ "region_name": self.region_name,
+ "poll_interval": self.poll_interval,
+ },
+ )
+
+ @cached_property
+ def hook(self) -> BatchClientHook:
+ return BatchClientHook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
+
+ async def run(self):
+
+ async with self.hook.async_conn as client:
+ waiter = self.hook.get_waiter("batch_job_complete",
deferrable=True, client=client)
+ attempt = 0
+ while attempt < self.max_retries:
+ attempt = attempt + 1
+ try:
+ await waiter.wait(
+ jobs=[self.job_id],
+ WaiterConfig={
+ "Delay": self.poll_interval,
+ "MaxAttempts": 1,
+ },
+ )
+ break
+ except WaiterError as error:
+ if "terminal failure" in str(error):
+ yield TriggerEvent(
+ {"status": "failure", "message": f"Delete Cluster
Failed: {error}"}
+ )
+ break
+ self.log.info(
+ "Job status is %s. Retrying attempt %s/%s",
+ error.last_response["jobs"][0]["status"],
+ attempt,
+ self.max_retries,
+ )
+ await asyncio.sleep(int(self.poll_interval))
+
+ if attempt >= self.max_retries:
+ yield TriggerEvent({"status": "failure", "message": "Job Failed -
max attempts reached."})
+ else:
+ yield TriggerEvent({"status": "success", "job_id": self.job_id})
diff --git a/airflow/providers/amazon/aws/waiters/batch.json
b/airflow/providers/amazon/aws/waiters/batch.json
new file mode 100644
index 0000000000..fa9752ea14
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/batch.json
@@ -0,0 +1,25 @@
+{
+ "version": 2,
+ "waiters": {
+ "batch_job_complete": {
+ "delay": 300,
+ "operation": "DescribeJobs",
+ "maxAttempts": 100,
+ "description": "Wait until job is SUCCEEDED",
+ "acceptors": [
+ {
+ "argument": "jobs[].status",
+ "expected": "SUCCEEDED",
+ "matcher": "pathAll",
+ "state": "success"
+ },
+ {
+ "argument": "jobs[].status",
+ "expected": "FAILED",
+ "matcher": "pathAll",
+ "state": "failed"
+ }
+ ]
+ }
+ }
+}
diff --git a/tests/providers/amazon/aws/operators/test_batch.py
b/tests/providers/amazon/aws/operators/test_batch.py
index 42f7fae86c..f559424dff 100644
--- a/tests/providers/amazon/aws/operators/test_batch.py
+++ b/tests/providers/amazon/aws/operators/test_batch.py
@@ -22,11 +22,13 @@ from unittest.mock import patch
import pytest
-from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning, TaskDeferred
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.operators.batch import
BatchCreateComputeEnvironmentOperator, BatchOperator
# Use dummy AWS credentials
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+
AWS_REGION = "eu-west-1"
AWS_ACCESS_KEY_ID = "airflow_dummy_key"
AWS_SECRET_ACCESS_KEY = "airflow_dummy_secret"
@@ -256,6 +258,21 @@ class TestBatchOperator:
container_overrides={"a": "b"},
)
+
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
+ def test_defer_if_deferrable_param_set(self, mock_client):
+ batch = BatchOperator(
+ task_id="task",
+ job_name=JOB_NAME,
+ job_queue="queue",
+ job_definition="hello-world",
+ do_xcom_push=False,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ batch.execute(context=None)
+ assert isinstance(exc.value.trigger, BatchOperatorTrigger), "Trigger
is not a BatchOperatorTrigger"
+
class TestBatchCreateComputeEnvironmentOperator:
@mock.patch.object(BatchClientHook, "client")
diff --git a/tests/providers/amazon/aws/triggers/test_batch.py
b/tests/providers/amazon/aws/triggers/test_batch.py
new file mode 100644
index 0000000000..54a8765ab4
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_batch.py
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
+from airflow.triggers.base import TriggerEvent
+from tests.providers.amazon.aws.utils.compat import AsyncMock, async_mock
+
+BATCH_JOB_ID = "job_id"
+POLL_INTERVAL = 5
+MAX_ATTEMPT = 5
+AWS_CONN_ID = "aws_batch_job_conn"
+AWS_REGION = "us-east-2"
+
+
+class TestBatchOperatorTrigger:
+ def test_batch_operator_trigger_serialize(self):
+ batch_trigger = BatchOperatorTrigger(
+ job_id=BATCH_JOB_ID,
+ poll_interval=POLL_INTERVAL,
+ max_retries=MAX_ATTEMPT,
+ aws_conn_id=AWS_CONN_ID,
+ region_name=AWS_REGION,
+ )
+ class_path, args = batch_trigger.serialize()
+ assert class_path ==
"airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger"
+ assert args["job_id"] == BATCH_JOB_ID
+ assert args["poll_interval"] == POLL_INTERVAL
+ assert args["max_retries"] == MAX_ATTEMPT
+ assert args["aws_conn_id"] == AWS_CONN_ID
+ assert args["region_name"] == AWS_REGION
+
+ @pytest.mark.asyncio
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_waiter")
+
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.async_conn")
+ async def test_batch_job_trigger_run(self, mock_async_conn,
mock_get_waiter):
+ mock = async_mock.MagicMock()
+ mock_async_conn.__aenter__.return_value = mock
+
+ mock_get_waiter().wait = AsyncMock()
+
+ batch_trigger = BatchOperatorTrigger(
+ job_id=BATCH_JOB_ID,
+ poll_interval=POLL_INTERVAL,
+ max_retries=MAX_ATTEMPT,
+ aws_conn_id=AWS_CONN_ID,
+ region_name=AWS_REGION,
+ )
+
+ generator = batch_trigger.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent({"status": "success", "job_id":
BATCH_JOB_ID})