This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new af28b41  Add sensor for AWS Batch (#19850) (#19885)
af28b41 is described below

commit af28b4190316401c9dfec6108d22b0525974eadb
Author: Yeshwanth Balachander <[email protected]>
AuthorDate: Sun Dec 5 13:52:30 2021 -0800

    Add sensor for AWS Batch (#19850) (#19885)
    
    * Add sensor for AWS Batch (#19850)
    
    Adds a sensor implementation to ask for the status of an
    AWS Batch job. The sensor will enable DAGs to wait for the
    batch job to reach a terminal state before proceeding to the
    downstream tasks.
---
 airflow/providers/amazon/aws/hooks/batch_client.py | 22 ++++--
 airflow/providers/amazon/aws/sensors/batch.py      | 78 ++++++++++++++++++++++
 airflow/providers/amazon/provider.yaml             |  3 +
 tests/providers/amazon/aws/sensors/test_batch.py   | 75 +++++++++++++++++++++
 4 files changed, 172 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py
index da78ecd..ae74e73 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -15,7 +15,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 """
 A client for AWS batch services
 
@@ -195,6 +194,17 @@ class AwsBatchClientHook(AwsBaseHook):
     DEFAULT_DELAY_MIN = 1
     DEFAULT_DELAY_MAX = 10
 
+    FAILURE_STATE = 'FAILED'
+    SUCCESS_STATE = 'SUCCEEDED'
+    RUNNING_STATE = 'RUNNING'
+    INTERMEDIATE_STATES = (
+        'SUBMITTED',
+        'PENDING',
+        'RUNNABLE',
+        'STARTING',
+        RUNNING_STATE,
+    )
+
     def __init__(
         self, *args, max_retries: Optional[int] = None, status_retries: Optional[int] = None, **kwargs
     ) -> None:
@@ -245,14 +255,14 @@ class AwsBatchClientHook(AwsBaseHook):
         job = self.get_job_description(job_id)
         job_status = job.get("status")
 
-        if job_status == "SUCCEEDED":
+        if job_status == self.SUCCESS_STATE:
             self.log.info("AWS batch job (%s) succeeded: %s", job_id, job)
             return True
 
-        if job_status == "FAILED":
+        if job_status == self.FAILURE_STATE:
             raise AirflowException(f"AWS Batch job ({job_id}) failed: {job}")
 
-        if job_status in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]:
+        if job_status in self.INTERMEDIATE_STATES:
             raise AirflowException(f"AWS Batch job ({job_id}) is not complete: {job}")
 
         raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}")
@@ -295,7 +305,7 @@ class AwsBatchClientHook(AwsBaseHook):
         :raises: AirflowException
         """
         self.delay(delay)
-        running_status = ["RUNNING", "SUCCEEDED", "FAILED"]
+        running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE]
         self.poll_job_status(job_id, running_status)
 
     def poll_for_job_complete(self, job_id: str, delay: Union[int, float, None] = None) -> None:
@@ -315,7 +325,7 @@ class AwsBatchClientHook(AwsBaseHook):
         :raises: AirflowException
         """
         self.delay(delay)
-        complete_status = ["SUCCEEDED", "FAILED"]
+        complete_status = [self.SUCCESS_STATE, self.FAILURE_STATE]
         self.poll_job_status(job_id, complete_status)
 
     def poll_job_status(self, job_id: str, match_status: List[str]) -> bool:
diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py
new file mode 100644
index 0000000..35555cc
--- /dev/null
+++ b/airflow/providers/amazon/aws/sensors/batch.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Dict, Optional
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.batch_client import AwsBatchClientHook
+from airflow.sensors.base import BaseSensorOperator
+
+
+class BatchSensor(BaseSensorOperator):
+    """
+    Asks for the state of the Batch Job execution until it reaches a failure state or success state.
+    If the job fails, the task will fail.
+
+    :param job_id: Batch job_id to check the state for
+    :type job_id: str
+    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+    :type aws_conn_id: str
+    """
+
+    template_fields = ['job_id']
+    template_ext = ()
+    ui_color = '#66c3ff'
+
+    def __init__(
+        self,
+        *,
+        job_id: str,
+        aws_conn_id: str = 'aws_default',
+        region_name: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.job_id = job_id
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.hook: Optional[AwsBatchClientHook] = None
+
+    def poke(self, context: Dict) -> bool:
+        job_description = self.get_hook().get_job_description(self.job_id)
+        state = job_description['status']
+
+        if state == AwsBatchClientHook.SUCCESS_STATE:
+            return True
+
+        if state in AwsBatchClientHook.INTERMEDIATE_STATES:
+            return False
+
+        if state == AwsBatchClientHook.FAILURE_STATE:
+            raise AirflowException(f'Batch sensor failed. AWS Batch job status: {state}')
+
+        raise AirflowException(f'Batch sensor failed. Unknown AWS Batch job status: {state}')
+
+    def get_hook(self) -> AwsBatchClientHook:
+        """Create and return a AwsBatchClientHook"""
+        if self.hook:
+            return self.hook
+
+        self.hook = AwsBatchClientHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+        )
+        return self.hook
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 327efac..f7c392f 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -251,6 +251,9 @@ sensors:
   - integration-name: Amazon Athena
     python-modules:
       - airflow.providers.amazon.aws.sensors.athena
+  - integration-name: AWS Batch
+    python-modules:
+      - airflow.providers.amazon.aws.sensors.batch
   - integration-name: Amazon CloudFormation
     python-modules:
       - airflow.providers.amazon.aws.sensors.cloud_formation
diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py
new file mode 100644
index 0000000..fd03451
--- /dev/null
+++ b/tests/providers/amazon/aws/sensors/test_batch.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest import mock
+
+from parameterized import parameterized
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.batch_client import AwsBatchClientHook
+from airflow.providers.amazon.aws.sensors.batch import BatchSensor
+
+TASK_ID = 'batch_job_sensor'
+JOB_ID = '8222a1c2-b246-4e19-b1b8-0039bb4407c0'
+
+
+class TestBatchSensor(unittest.TestCase):
+    def setUp(self):
+        self.batch_sensor = BatchSensor(
+            task_id='batch_job_sensor',
+            job_id=JOB_ID,
+        )
+
+    @mock.patch.object(AwsBatchClientHook, 'get_job_description')
+    def test_poke_on_success_state(self, mock_get_job_description):
+        mock_get_job_description.return_value = {'status': 'SUCCEEDED'}
+        self.assertTrue(self.batch_sensor.poke(None))
+        mock_get_job_description.assert_called_once_with(JOB_ID)
+
+    @mock.patch.object(AwsBatchClientHook, 'get_job_description')
+    def test_poke_on_failure_state(self, mock_get_job_description):
+        mock_get_job_description.return_value = {'status': 'FAILED'}
+        with self.assertRaises(AirflowException) as e:
+            self.batch_sensor.poke(None)
+
+        self.assertEqual('Batch sensor failed. AWS Batch job status: FAILED', str(e.exception))
+        mock_get_job_description.assert_called_once_with(JOB_ID)
+
+    @mock.patch.object(AwsBatchClientHook, 'get_job_description')
+    def test_poke_on_invalid_state(self, mock_get_job_description):
+        mock_get_job_description.return_value = {'status': 'INVALID'}
+        with self.assertRaises(AirflowException) as e:
+            self.batch_sensor.poke(None)
+
+        self.assertEqual('Batch sensor failed. Unknown AWS Batch job status: INVALID', str(e.exception))
+        mock_get_job_description.assert_called_once_with(JOB_ID)
+
+    @parameterized.expand(
+        [
+            ('SUBMITTED',),
+            ('PENDING',),
+            ('RUNNABLE',),
+            ('STARTING',),
+            ('RUNNING',),
+        ]
+    )
+    @mock.patch.object(AwsBatchClientHook, 'get_job_description')
+    def test_poke_on_intermediate_state(self, job_status, mock_get_job_description):
+        mock_get_job_description.return_value = {'status': job_status}
+        self.assertFalse(self.batch_sensor.poke(None))
+        mock_get_job_description.assert_called_once_with(JOB_ID)

Reply via email to