This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3193857376 Add Deferrable Mode for EC2StateSensor (#31130)
3193857376 is described below
commit 3193857376bc2c8cd2eb133017be1e8cbcaa8405
Author: Syed Hussaain <[email protected]>
AuthorDate: Sat May 13 12:49:43 2023 -0700
Add Deferrable Mode for EC2StateSensor (#31130)
---
airflow/providers/amazon/aws/hooks/ec2.py | 5 +
airflow/providers/amazon/aws/sensors/ec2.py | 27 +++++-
.../amazon/aws/{sensors => triggers}/ec2.py | 66 ++++++-------
tests/providers/amazon/aws/sensors/test_ec2.py | 19 ++++
tests/providers/amazon/aws/triggers/test_ec2.py | 103 +++++++++++++++++++++
5 files changed, 188 insertions(+), 32 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/ec2.py b/airflow/providers/amazon/aws/hooks/ec2.py
index 25fc7bf18f..91b79aa8a3 100644
--- a/airflow/providers/amazon/aws/hooks/ec2.py
+++ b/airflow/providers/amazon/aws/hooks/ec2.py
@@ -166,6 +166,11 @@ class EC2Hook(AwsBaseHook):
"""
return [instance["InstanceId"] for instance in self.get_instances(filters=filters)]
+ async def get_instance_state_async(self, instance_id: str) -> str:
+ async with self.async_conn as client:
+ response = await client.describe_instances(InstanceIds=[instance_id])
+ return response["Reservations"][0]["Instances"][0]["State"]["Name"]
+
def get_instance_state(self, instance_id: str) -> str:
"""
Get EC2 instance state by id and return it.
diff --git a/airflow/providers/amazon/aws/sensors/ec2.py b/airflow/providers/amazon/aws/sensors/ec2.py
index 4377a26444..5ce49c2c0a 100644
--- a/airflow/providers/amazon/aws/sensors/ec2.py
+++ b/airflow/providers/amazon/aws/sensors/ec2.py
@@ -17,10 +17,12 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
+from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger
from airflow.sensors.base import BaseSensorOperator
if TYPE_CHECKING:
@@ -39,6 +41,7 @@ class EC2InstanceStateSensor(BaseSensorOperator):
:param target_state: target state of instance
:param instance_id: id of the AWS EC2 instance
:param region_name: (optional) aws region name associated with the client
+ :param deferrable: if True, the sensor will run in deferrable mode
"""
template_fields: Sequence[str] = ("target_state", "instance_id", "region_name")
@@ -53,6 +56,7 @@ class EC2InstanceStateSensor(BaseSensorOperator):
instance_id: str,
aws_conn_id: str = "aws_default",
region_name: str | None = None,
+ deferrable: bool = False,
**kwargs,
):
if target_state not in self.valid_states:
@@ -62,6 +66,22 @@ class EC2InstanceStateSensor(BaseSensorOperator):
self.instance_id = instance_id
self.aws_conn_id = aws_conn_id
self.region_name = region_name
+ self.deferrable = deferrable
+
+ def execute(self, context: Context) -> Any:
+ if self.deferrable:
+ self.defer(
+ trigger=EC2StateSensorTrigger(
+ instance_id=self.instance_id,
+ target_state=self.target_state,
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ poll_interval=int(self.poke_interval),
+ ),
+ method_name="execute_complete",
+ )
+ else:
+ super().execute(context=context)
@cached_property
def hook(self):
@@ -71,3 +91,8 @@ class EC2InstanceStateSensor(BaseSensorOperator):
instance_state = self.hook.get_instance_state(instance_id=self.instance_id)
self.log.info("instance state: %s", instance_state)
return instance_state == self.target_state
+
+ def execute_complete(self, context, event=None):
+ if event["status"] != "success":
+ raise AirflowException(f"Error: {event}")
+ return
diff --git a/airflow/providers/amazon/aws/sensors/ec2.py b/airflow/providers/amazon/aws/triggers/ec2.py
similarity index 52%
copy from airflow/providers/amazon/aws/sensors/ec2.py
copy to airflow/providers/amazon/aws/triggers/ec2.py
index 4377a26444..79bae2895d 100644
--- a/airflow/providers/amazon/aws/sensors/ec2.py
+++ b/airflow/providers/amazon/aws/triggers/ec2.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,57 +16,62 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING, Sequence
+import asyncio
+from typing import Any
from airflow.compat.functools import cached_property
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
-from airflow.sensors.base import BaseSensorOperator
-
-if TYPE_CHECKING:
- from airflow.utils.context import Context
+from airflow.triggers.base import BaseTrigger, TriggerEvent
-class EC2InstanceStateSensor(BaseSensorOperator):
+class EC2StateSensorTrigger(BaseTrigger):
"""
- Check the state of the AWS EC2 instance until
- state of the instance become equal to the target state.
-
- .. seealso::
- For more information on how to use this sensor, take a look at the guide:
- :ref:`howto/sensor:EC2InstanceStateSensor`
+ Trigger for EC2StateSensor. The Trigger polls the EC2 instance, and yields a TriggerEvent once
+ the state of the instance matches the `target_state`.
- :param target_state: target state of instance
:param instance_id: id of the AWS EC2 instance
+ :param target_state: target state of instance
+ :param aws_conn_id: aws connection to use
:param region_name: (optional) aws region name associated with the client
+ :param poll_interval: number of seconds to wait before attempting the next poll
"""
- template_fields: Sequence[str] = ("target_state", "instance_id", "region_name")
- ui_color = "#cc8811"
- ui_fgcolor = "#ffffff"
- valid_states = ["running", "stopped", "terminated"]
-
def __init__(
self,
- *,
- target_state: str,
instance_id: str,
+ target_state: str,
aws_conn_id: str = "aws_default",
region_name: str | None = None,
- **kwargs,
+ poll_interval: int = 60,
):
- if target_state not in self.valid_states:
- raise ValueError(f"Invalid target_state: {target_state}")
- super().__init__(**kwargs)
- self.target_state = target_state
self.instance_id = instance_id
+ self.target_state = target_state
self.aws_conn_id = aws_conn_id
self.region_name = region_name
+ self.poll_interval = poll_interval
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+ "airflow.providers.amazon.aws.triggers.ec2.EC2StateSensorTrigger",
+ {
+ "instance_id": self.instance_id,
+ "target_state": self.target_state,
+ "aws_conn_id": self.aws_conn_id,
+ "region_name": self.region_name,
+ "poll_interval": self.poll_interval,
+ },
+ )
@cached_property
def hook(self):
- return EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+ return EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
- def poke(self, context: Context):
- instance_state = self.hook.get_instance_state(instance_id=self.instance_id)
- self.log.info("instance state: %s", instance_state)
- return instance_state == self.target_state
+ async def run(self):
+ while True:
+ instance_state = await self.hook.get_instance_state_async(instance_id=self.instance_id)
+ self.log.info("instance state: %s", instance_state)
+ if instance_state == self.target_state:
+ yield TriggerEvent({"status": "success", "message": "target state met"})
+ break
+ else:
+ await asyncio.sleep(self.poll_interval)
diff --git a/tests/providers/amazon/aws/sensors/test_ec2.py b/tests/providers/amazon/aws/sensors/test_ec2.py
index acc626eb76..3f5a6f09f7 100644
--- a/tests/providers/amazon/aws/sensors/test_ec2.py
+++ b/tests/providers/amazon/aws/sensors/test_ec2.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import pytest
from moto import mock_ec2
+from airflow.exceptions import TaskDeferred
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
from airflow.providers.amazon.aws.sensors.ec2 import EC2InstanceStateSensor
@@ -126,3 +127,21 @@ class TestEC2InstanceStateSensor:
ec2_hook.get_instance(instance_id=instance_id).terminate()
# assert instance state is terminated
assert stop_sensor.poke(None)
+
+ @mock_ec2
+ def test_deferrable(self):
+ # create instance
+ ec2_hook = EC2Hook()
+ instance_id = self._create_instance(ec2_hook)
+ # start instance
+ ec2_hook.get_instance(instance_id=instance_id).start()
+
+ # stop sensor, waits until ec2 instance state became terminated
+ deferrable_sensor = EC2InstanceStateSensor(
+ task_id="deferrable_sensor",
+ target_state="terminated",
+ instance_id=instance_id,
+ deferrable=True,
+ )
+ with pytest.raises(TaskDeferred):
+ deferrable_sensor.execute(context=None)
diff --git a/tests/providers/amazon/aws/triggers/test_ec2.py b/tests/providers/amazon/aws/triggers/test_ec2.py
new file mode 100644
index 0000000000..51943079d1
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_ec2.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import sys
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger
+from airflow.triggers.base import TriggerEvent
+
+if sys.version_info < (3, 8):
+ from asynctest import CoroutineMock as AsyncMock, mock as async_mock
+else:
+ from unittest import mock as async_mock
+ from unittest.mock import AsyncMock
+
+TEST_INSTANCE_ID = "test-instance-id"
+TEST_TARGET_STATE = "test_state"
+TEST_CONN_ID = "test_conn_id"
+TEST_REGION_NAME = "test-region"
+TEST_POLL_INTERVAL = 100
+
+
+class TestEC2StateSensorTrigger:
+ def test_ec2_state_sensor_trigger_serialize(self):
+ test_ec2_state_sensor = EC2StateSensorTrigger(
+ instance_id=TEST_INSTANCE_ID,
+ target_state=TEST_TARGET_STATE,
+ aws_conn_id=TEST_CONN_ID,
+ region_name=TEST_REGION_NAME,
+ poll_interval=TEST_POLL_INTERVAL,
+ )
+
+ class_path, args = test_ec2_state_sensor.serialize()
+ assert class_path == "airflow.providers.amazon.aws.triggers.ec2.EC2StateSensorTrigger"
+ assert args["instance_id"] == TEST_INSTANCE_ID
+ assert args["target_state"] == TEST_TARGET_STATE
+ assert args["aws_conn_id"] == TEST_CONN_ID
+ assert args["region_name"] == TEST_REGION_NAME
+ assert args["poll_interval"] == TEST_POLL_INTERVAL
+
+ @pytest.mark.asyncio
+ @async_mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.get_instance_state_async")
+ @async_mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.async_conn")
+ async def test_ec2_state_sensor_run(self, mock_async_conn, mock_get_instance_state_async):
+ mock = AsyncMock()
+ mock_async_conn.__aenter__.return_value = mock
+ mock_get_instance_state_async.return_value = TEST_TARGET_STATE
+
+ test_ec2_state_sensor = EC2StateSensorTrigger(
+ instance_id=TEST_INSTANCE_ID,
+ target_state=TEST_TARGET_STATE,
+ aws_conn_id=TEST_CONN_ID,
+ region_name=TEST_REGION_NAME,
+ poll_interval=TEST_POLL_INTERVAL,
+ )
+
+ generator = test_ec2_state_sensor.run()
+ response = await generator.asend(None)
+
+ assert response == TriggerEvent({"status": "success", "message": "target state met"})
+
+ @pytest.mark.asyncio
+ @async_mock.patch("asyncio.sleep")
+ @async_mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.get_instance_state_async")
+ @async_mock.patch("airflow.providers.amazon.aws.hooks.ec2.EC2Hook.async_conn")
+ async def test_ec2_state_sensor_run_multiple(
+ self, mock_async_conn, mock_get_instance_state_async, mock_sleep
+ ):
+ mock = AsyncMock()
+ mock_async_conn.__aenter__.return_value = mock
+ mock_get_instance_state_async.side_effect = ["test-state", TEST_TARGET_STATE]
+ mock_sleep.return_value = True
+
+ test_ec2_state_sensor = EC2StateSensorTrigger(
+ instance_id=TEST_INSTANCE_ID,
+ target_state=TEST_TARGET_STATE,
+ aws_conn_id=TEST_CONN_ID,
+ region_name=TEST_REGION_NAME,
+ poll_interval=TEST_POLL_INTERVAL,
+ )
+
+ generator = test_ec2_state_sensor.run()
+ response = await generator.asend(None)
+
+ assert mock_get_instance_state_async.call_count == 2
+
+ assert response == TriggerEvent({"status": "success", "message": "target state met"})