This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new af28b41 Add sensor for AWS Batch (#19850) (#19885)
af28b41 is described below
commit af28b4190316401c9dfec6108d22b0525974eadb
Author: Yeshwanth Balachander <[email protected]>
AuthorDate: Sun Dec 5 13:52:30 2021 -0800
Add sensor for AWS Batch (#19850) (#19885)
* Add sensor for AWS Batch (#19850)
Adds a sensor implementation that polls the status of an
AWS Batch job. The sensor enables DAGs to wait for the
Batch job to reach a terminal state before proceeding to
the downstream tasks.
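For reference, a minimal sketch of how the new BatchSensor might be wired into a DAG. The dag_id, schedule, and hard-coded job_id below are illustrative assumptions, not part of this commit; in practice the job_id would typically come from a preceding submit step.

    # Illustrative sketch only -- dag_id, schedule, and job_id are assumed values.
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.sensors.batch import BatchSensor

    with DAG(
        dag_id="example_batch_sensor",
        start_date=datetime(2021, 12, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        # Returns True once the job reaches SUCCEEDED, keeps poking while the
        # job is in an intermediate state, and fails the task on FAILED or an
        # unknown status.
        wait_for_batch_job = BatchSensor(
            task_id="wait_for_batch_job",
            job_id="8222a1c2-b246-4e19-b1b8-0039bb4407c0",
            aws_conn_id="aws_default",
        )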
---
airflow/providers/amazon/aws/hooks/batch_client.py | 22 ++++--
airflow/providers/amazon/aws/sensors/batch.py | 78 ++++++++++++++++++++++
airflow/providers/amazon/provider.yaml | 3 +
tests/providers/amazon/aws/sensors/test_batch.py | 75 +++++++++++++++++++++
4 files changed, 172 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py
index da78ecd..ae74e73 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
A client for AWS batch services
@@ -195,6 +194,17 @@ class AwsBatchClientHook(AwsBaseHook):
DEFAULT_DELAY_MIN = 1
DEFAULT_DELAY_MAX = 10
+ FAILURE_STATE = 'FAILED'
+ SUCCESS_STATE = 'SUCCEEDED'
+ RUNNING_STATE = 'RUNNING'
+ INTERMEDIATE_STATES = (
+ 'SUBMITTED',
+ 'PENDING',
+ 'RUNNABLE',
+ 'STARTING',
+ RUNNING_STATE,
+ )
+
def __init__(
self, *args, max_retries: Optional[int] = None, status_retries: Optional[int] = None, **kwargs
) -> None:
@@ -245,14 +255,14 @@ class AwsBatchClientHook(AwsBaseHook):
job = self.get_job_description(job_id)
job_status = job.get("status")
- if job_status == "SUCCEEDED":
+ if job_status == self.SUCCESS_STATE:
self.log.info("AWS batch job (%s) succeeded: %s", job_id, job)
return True
- if job_status == "FAILED":
+ if job_status == self.FAILURE_STATE:
raise AirflowException(f"AWS Batch job ({job_id}) failed: {job}")
- if job_status in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]:
+ if job_status in self.INTERMEDIATE_STATES:
raise AirflowException(f"AWS Batch job ({job_id}) is not complete:
{job}")
raise AirflowException(f"AWS Batch job ({job_id}) has unknown status:
{job}")
@@ -295,7 +305,7 @@ class AwsBatchClientHook(AwsBaseHook):
:raises: AirflowException
"""
self.delay(delay)
- running_status = ["RUNNING", "SUCCEEDED", "FAILED"]
+ running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE]
self.poll_job_status(job_id, running_status)
def poll_for_job_complete(self, job_id: str, delay: Union[int, float, None] = None) -> None:
@@ -315,7 +325,7 @@ class AwsBatchClientHook(AwsBaseHook):
:raises: AirflowException
"""
self.delay(delay)
- complete_status = ["SUCCEEDED", "FAILED"]
+ complete_status = [self.SUCCESS_STATE, self.FAILURE_STATE]
self.poll_job_status(job_id, complete_status)
def poll_job_status(self, job_id: str, match_status: List[str]) -> bool:
diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py
new file mode 100644
index 0000000..35555cc
--- /dev/null
+++ b/airflow/providers/amazon/aws/sensors/batch.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Dict, Optional
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.batch_client import AwsBatchClientHook
+from airflow.sensors.base import BaseSensorOperator
+
+
+class BatchSensor(BaseSensorOperator):
+ """
+ Asks for the state of the Batch Job execution until it reaches a failure state or success state.
+ If the job fails, the task will fail.
+
+ :param job_id: Batch job_id to check the state for
+ :type job_id: str
+ :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+ :type aws_conn_id: str
+ """
+
+ template_fields = ['job_id']
+ template_ext = ()
+ ui_color = '#66c3ff'
+
+ def __init__(
+ self,
+ *,
+ job_id: str,
+ aws_conn_id: str = 'aws_default',
+ region_name: Optional[str] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.job_id = job_id
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+ self.hook: Optional[AwsBatchClientHook] = None
+
+ def poke(self, context: Dict) -> bool:
+ job_description = self.get_hook().get_job_description(self.job_id)
+ state = job_description['status']
+
+ if state == AwsBatchClientHook.SUCCESS_STATE:
+ return True
+
+ if state in AwsBatchClientHook.INTERMEDIATE_STATES:
+ return False
+
+ if state == AwsBatchClientHook.FAILURE_STATE:
+ raise AirflowException(f'Batch sensor failed. AWS Batch job status: {state}')
+
+ raise AirflowException(f'Batch sensor failed. Unknown AWS Batch job status: {state}')
+
+ def get_hook(self) -> AwsBatchClientHook:
+ """Create and return a AwsBatchClientHook"""
+ if self.hook:
+ return self.hook
+
+ self.hook = AwsBatchClientHook(
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ )
+ return self.hook
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 327efac..f7c392f 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -251,6 +251,9 @@ sensors:
- integration-name: Amazon Athena
python-modules:
- airflow.providers.amazon.aws.sensors.athena
+ - integration-name: AWS Batch
+ python-modules:
+ - airflow.providers.amazon.aws.sensors.batch
- integration-name: Amazon CloudFormation
python-modules:
- airflow.providers.amazon.aws.sensors.cloud_formation
diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py
new file mode 100644
index 0000000..fd03451
--- /dev/null
+++ b/tests/providers/amazon/aws/sensors/test_batch.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest import mock
+
+from parameterized import parameterized
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.batch_client import AwsBatchClientHook
+from airflow.providers.amazon.aws.sensors.batch import BatchSensor
+
+TASK_ID = 'batch_job_sensor'
+JOB_ID = '8222a1c2-b246-4e19-b1b8-0039bb4407c0'
+
+
+class TestBatchSensor(unittest.TestCase):
+ def setUp(self):
+ self.batch_sensor = BatchSensor(
+ task_id='batch_job_sensor',
+ job_id=JOB_ID,
+ )
+
+ @mock.patch.object(AwsBatchClientHook, 'get_job_description')
+ def test_poke_on_success_state(self, mock_get_job_description):
+ mock_get_job_description.return_value = {'status': 'SUCCEEDED'}
+ self.assertTrue(self.batch_sensor.poke(None))
+ mock_get_job_description.assert_called_once_with(JOB_ID)
+
+ @mock.patch.object(AwsBatchClientHook, 'get_job_description')
+ def test_poke_on_failure_state(self, mock_get_job_description):
+ mock_get_job_description.return_value = {'status': 'FAILED'}
+ with self.assertRaises(AirflowException) as e:
+ self.batch_sensor.poke(None)
+
+ self.assertEqual('Batch sensor failed. AWS Batch job status: FAILED', str(e.exception))
+ mock_get_job_description.assert_called_once_with(JOB_ID)
+
+ @mock.patch.object(AwsBatchClientHook, 'get_job_description')
+ def test_poke_on_invalid_state(self, mock_get_job_description):
+ mock_get_job_description.return_value = {'status': 'INVALID'}
+ with self.assertRaises(AirflowException) as e:
+ self.batch_sensor.poke(None)
+
+ self.assertEqual('Batch sensor failed. Unknown AWS Batch job status: INVALID', str(e.exception))
+ mock_get_job_description.assert_called_once_with(JOB_ID)
+
+ @parameterized.expand(
+ [
+ ('SUBMITTED',),
+ ('PENDING',),
+ ('RUNNABLE',),
+ ('STARTING',),
+ ('RUNNING',),
+ ]
+ )
+ @mock.patch.object(AwsBatchClientHook, 'get_job_description')
+ def test_poke_on_intermediate_state(self, job_status, mock_get_job_description):
+ mock_get_job_description.return_value = {'status': job_status}
+ self.assertFalse(self.batch_sensor.poke(None))
+ mock_get_job_description.assert_called_once_with(JOB_ID)