This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3193857376 Add Deferrable Mode for EC2StateSensor (#31130)
3193857376 is described below

commit 3193857376bc2c8cd2eb133017be1e8cbcaa8405
Author: Syed Hussaain <[email protected]>
AuthorDate: Sat May 13 12:49:43 2023 -0700

    Add Deferrable Mode for EC2StateSensor (#31130)
---
 airflow/providers/amazon/aws/hooks/ec2.py          |   5 +
 airflow/providers/amazon/aws/sensors/ec2.py        |  27 +++++-
 .../amazon/aws/{sensors => triggers}/ec2.py        |  66 ++++++-------
 tests/providers/amazon/aws/sensors/test_ec2.py     |  19 ++++
 tests/providers/amazon/aws/triggers/test_ec2.py    | 103 +++++++++++++++++++++
 5 files changed, 188 insertions(+), 32 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/ec2.py b/airflow/providers/amazon/aws/hooks/ec2.py
index 25fc7bf18f..91b79aa8a3 100644
--- a/airflow/providers/amazon/aws/hooks/ec2.py
+++ b/airflow/providers/amazon/aws/hooks/ec2.py
@@ -166,6 +166,11 @@ class EC2Hook(AwsBaseHook):
         """
         return [instance["InstanceId"] for instance in 
self.get_instances(filters=filters)]
 
+    async def get_instance_state_async(self, instance_id: str) -> str:
+        async with self.async_conn as client:
+            response = await 
client.describe_instances(InstanceIds=[instance_id])
+            return response["Reservations"][0]["Instances"][0]["State"]["Name"]
+
     def get_instance_state(self, instance_id: str) -> str:
         """
         Get EC2 instance state by id and return it.
diff --git a/airflow/providers/amazon/aws/sensors/ec2.py b/airflow/providers/amazon/aws/sensors/ec2.py
index 4377a26444..5ce49c2c0a 100644
--- a/airflow/providers/amazon/aws/sensors/ec2.py
+++ b/airflow/providers/amazon/aws/sensors/ec2.py
@@ -17,10 +17,12 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
+from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -39,6 +41,7 @@ class EC2InstanceStateSensor(BaseSensorOperator):
     :param target_state: target state of instance
     :param instance_id: id of the AWS EC2 instance
     :param region_name: (optional) aws region name associated with the client
+    :param deferrable: if True, the sensor will run in deferrable mode
     """
 
     template_fields: Sequence[str] = ("target_state", "instance_id", 
"region_name")
@@ -53,6 +56,7 @@ class EC2InstanceStateSensor(BaseSensorOperator):
         instance_id: str,
         aws_conn_id: str = "aws_default",
         region_name: str | None = None,
+        deferrable: bool = False,
         **kwargs,
     ):
         if target_state not in self.valid_states:
@@ -62,6 +66,22 @@ class EC2InstanceStateSensor(BaseSensorOperator):
         self.instance_id = instance_id
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=EC2StateSensorTrigger(
+                    instance_id=self.instance_id,
+                    target_state=self.target_state,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    poll_interval=int(self.poke_interval),
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
 
     @cached_property
     def hook(self):
@@ -71,3 +91,8 @@ class EC2InstanceStateSensor(BaseSensorOperator):
         instance_state = 
self.hook.get_instance_state(instance_id=self.instance_id)
         self.log.info("instance state: %s", instance_state)
         return instance_state == self.target_state
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error: {event}")
+        return
diff --git a/airflow/providers/amazon/aws/sensors/ec2.py b/airflow/providers/amazon/aws/triggers/ec2.py
similarity index 52%
copy from airflow/providers/amazon/aws/sensors/ec2.py
copy to airflow/providers/amazon/aws/triggers/ec2.py
index 4377a26444..79bae2895d 100644
--- a/airflow/providers/amazon/aws/sensors/ec2.py
+++ b/airflow/providers/amazon/aws/triggers/ec2.py
@@ -1,4 +1,3 @@
-#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -17,57 +16,62 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+import asyncio
+from typing import Any
 
 from airflow.compat.functools import cached_property
 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
-from airflow.sensors.base import BaseSensorOperator
-
-if TYPE_CHECKING:
-    from airflow.utils.context import Context
+from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
-class EC2InstanceStateSensor(BaseSensorOperator):
+class EC2StateSensorTrigger(BaseTrigger):
     """
-    Check the state of the AWS EC2 instance until
-    state of the instance become equal to the target state.
-
-    .. seealso::
-        For more information on how to use this sensor, take a look at the 
guide:
-        :ref:`howto/sensor:EC2InstanceStateSensor`
+    Trigger for EC2StateSensor. The Trigger polls the EC2 instance, and yields 
a TriggerEvent once
+    the state of the instance matches the `target_state`.
 
-    :param target_state: target state of instance
     :param instance_id: id of the AWS EC2 instance
+    :param target_state: target state of instance
+    :param aws_conn_id: aws connection to use
     :param region_name: (optional) aws region name associated with the client
+    :param poll_interval: number of seconds to wait before attempting the next 
poll
     """
 
-    template_fields: Sequence[str] = ("target_state", "instance_id", 
"region_name")
-    ui_color = "#cc8811"
-    ui_fgcolor = "#ffffff"
-    valid_states = ["running", "stopped", "terminated"]
-
     def __init__(
         self,
-        *,
-        target_state: str,
         instance_id: str,
+        target_state: str,
         aws_conn_id: str = "aws_default",
         region_name: str | None = None,
-        **kwargs,
+        poll_interval: int = 60,
     ):
-        if target_state not in self.valid_states:
-            raise ValueError(f"Invalid target_state: {target_state}")
-        super().__init__(**kwargs)
-        self.target_state = target_state
         self.instance_id = instance_id
+        self.target_state = target_state
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.amazon.aws.triggers.ec2.EC2StateSensorTrigger",
+            {
+                "instance_id": self.instance_id,
+                "target_state": self.target_state,
+                "aws_conn_id": self.aws_conn_id,
+                "region_name": self.region_name,
+                "poll_interval": self.poll_interval,
+            },
+        )
 
     @cached_property
     def hook(self):
-        return EC2Hook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
+        return EC2Hook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name, api_type="client_type")
 
-    def poke(self, context: Context):
-        instance_state = 
self.hook.get_instance_state(instance_id=self.instance_id)
-        self.log.info("instance state: %s", instance_state)
-        return instance_state == self.target_state
+    async def run(self):
+        while True:
+            instance_state = await 
self.hook.get_instance_state_async(instance_id=self.instance_id)
+            self.log.info("instance state: %s", instance_state)
+            if instance_state == self.target_state:
+                yield TriggerEvent({"status": "success", "message": "target 
state met"})
+                break
+            else:
+                await asyncio.sleep(self.poll_interval)
diff --git a/tests/providers/amazon/aws/sensors/test_ec2.py b/tests/providers/amazon/aws/sensors/test_ec2.py
index acc626eb76..3f5a6f09f7 100644
--- a/tests/providers/amazon/aws/sensors/test_ec2.py
+++ b/tests/providers/amazon/aws/sensors/test_ec2.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 import pytest
 from moto import mock_ec2
 
+from airflow.exceptions import TaskDeferred
 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
 from airflow.providers.amazon.aws.sensors.ec2 import EC2InstanceStateSensor
 
@@ -126,3 +127,21 @@ class TestEC2InstanceStateSensor:
         ec2_hook.get_instance(instance_id=instance_id).terminate()
         # assert instance state is terminated
         assert stop_sensor.poke(None)
+
+    @mock_ec2
+    def test_deferrable(self):
+        # create instance
+        ec2_hook = EC2Hook()
+        instance_id = self._create_instance(ec2_hook)
+        # start instance
+        ec2_hook.get_instance(instance_id=instance_id).start()
+
+        # stop sensor, waits until ec2 instance state became terminated
+        deferrable_sensor = EC2InstanceStateSensor(
+            task_id="deferrable_sensor",
+            target_state="terminated",
+            instance_id=instance_id,
+            deferrable=True,
+        )
+        with pytest.raises(TaskDeferred):
+            deferrable_sensor.execute(context=None)
diff --git a/tests/providers/amazon/aws/triggers/test_ec2.py b/tests/providers/amazon/aws/triggers/test_ec2.py
new file mode 100644
index 0000000000..51943079d1
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_ec2.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import sys
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger
+from airflow.triggers.base import TriggerEvent
+
+if sys.version_info < (3, 8):
+    from asynctest import CoroutineMock as AsyncMock, mock as async_mock
+else:
+    from unittest import mock as async_mock
+    from unittest.mock import AsyncMock
+
+TEST_INSTANCE_ID = "test-instance-id"
+TEST_TARGET_STATE = "test_state"
+TEST_CONN_ID = "test_conn_id"
+TEST_REGION_NAME = "test-region"
+TEST_POLL_INTERVAL = 100
+
+
+class TestEC2StateSensorTrigger:
+    def test_ec2_state_sensor_trigger_serialize(self):
+        test_ec2_state_sensor = EC2StateSensorTrigger(
+            instance_id=TEST_INSTANCE_ID,
+            target_state=TEST_TARGET_STATE,
+            aws_conn_id=TEST_CONN_ID,
+            region_name=TEST_REGION_NAME,
+            poll_interval=TEST_POLL_INTERVAL,
+        )
+
+        class_path, args = test_ec2_state_sensor.serialize()
+        assert class_path == 
"airflow.providers.amazon.aws.triggers.ec2.EC2StateSensorTrigger"
+        assert args["instance_id"] == TEST_INSTANCE_ID
+        assert args["target_state"] == TEST_TARGET_STATE
+        assert args["aws_conn_id"] == TEST_CONN_ID
+        assert args["region_name"] == TEST_REGION_NAME
+        assert args["poll_interval"] == TEST_POLL_INTERVAL
+
+    @pytest.mark.asyncio
+    
@async_mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.get_instance_state_async")
+    
@async_mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.async_conn")
+    async def test_ec2_state_sensor_run(self, mock_async_conn, 
mock_get_instance_state_async):
+        mock = AsyncMock()
+        mock_async_conn.__aenter__.return_value = mock
+        mock_get_instance_state_async.return_value = TEST_TARGET_STATE
+
+        test_ec2_state_sensor = EC2StateSensorTrigger(
+            instance_id=TEST_INSTANCE_ID,
+            target_state=TEST_TARGET_STATE,
+            aws_conn_id=TEST_CONN_ID,
+            region_name=TEST_REGION_NAME,
+            poll_interval=TEST_POLL_INTERVAL,
+        )
+
+        generator = test_ec2_state_sensor.run()
+        response = await generator.asend(None)
+
+        assert response == TriggerEvent({"status": "success", "message": 
"target state met"})
+
+    @pytest.mark.asyncio
+    @async_mock.patch("asyncio.sleep")
+    
@async_mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.get_instance_state_async")
+    
@async_mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.async_conn")
+    async def test_ec2_state_sensor_run_multiple(
+        self, mock_async_conn, mock_get_instance_state_async, mock_sleep
+    ):
+        mock = AsyncMock()
+        mock_async_conn.__aenter__.return_value = mock
+        mock_get_instance_state_async.side_effect = ["test-state", 
TEST_TARGET_STATE]
+        mock_sleep.return_value = True
+
+        test_ec2_state_sensor = EC2StateSensorTrigger(
+            instance_id=TEST_INSTANCE_ID,
+            target_state=TEST_TARGET_STATE,
+            aws_conn_id=TEST_CONN_ID,
+            region_name=TEST_REGION_NAME,
+            poll_interval=TEST_POLL_INTERVAL,
+        )
+
+        generator = test_ec2_state_sensor.run()
+        response = await generator.asend(None)
+
+        assert mock_get_instance_state_async.call_count == 2
+
+        assert response == TriggerEvent({"status": "success", "message": 
"target state met"})

Reply via email to