This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ca1202fd31 Add `EC2HibernateInstanceOperator` and
`EC2RebootInstanceOperator` (#35790)
ca1202fd31 is described below
commit ca1202fd31f0ea8c25833cf11a5f7aa97c1db87b
Author: David <[email protected]>
AuthorDate: Thu Nov 23 12:58:59 2023 -0500
Add `EC2HibernateInstanceOperator` and `EC2RebootInstanceOperator` (#35790)
---
airflow/providers/amazon/aws/operators/ec2.py | 126 ++++++++++++++++
.../operators/ec2.rst | 28 ++++
tests/providers/amazon/aws/operators/test_ec2.py | 167 +++++++++++++++++++++
tests/system/providers/amazon/aws/example_ec2.py | 19 +++
4 files changed, 340 insertions(+)
diff --git a/airflow/providers/amazon/aws/operators/ec2.py b/airflow/providers/amazon/aws/operators/ec2.py
index b9de533378..2dbb6986d7 100644
--- a/airflow/providers/amazon/aws/operators/ec2.py
+++ b/airflow/providers/amazon/aws/operators/ec2.py
@@ -19,6 +19,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
+from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
@@ -254,3 +255,128 @@ class EC2TerminateInstanceOperator(BaseOperator):
"MaxAttempts": self.max_attempts,
},
)
+
+
+class EC2RebootInstanceOperator(BaseOperator):
+ """
+ Reboot Amazon EC2 instances.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:EC2RebootInstanceOperator`
+
+ :param instance_ids: ID of the instance(s) to be rebooted.
+ :param aws_conn_id: AWS connection to use
+ :param region_name: AWS region name associated with the client.
+ :param poll_interval: Number of seconds to wait before attempting to
+ check state of instance. Only used if wait_for_completion is True. Default is 20.
+ :param max_attempts: Maximum number of attempts when checking state of instance.
+ Only used if wait_for_completion is True. Default is 20.
+ :param wait_for_completion: If True, the operator will wait for the instance to be
+ in the `running` state before returning.
+ """
+
+ template_fields: Sequence[str] = ("instance_ids", "region_name")
+ ui_color = "#eeaa11"
+ ui_fgcolor = "#ffffff"
+
+ def __init__(
+ self,
+ *,
+ instance_ids: str | list[str],
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
+ poll_interval: int = 20,
+ max_attempts: int = 20,
+ wait_for_completion: bool = False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.instance_ids = instance_ids
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+ self.poll_interval = poll_interval
+ self.max_attempts = max_attempts
+ self.wait_for_completion = wait_for_completion
+
+ def execute(self, context: Context):
+ if isinstance(self.instance_ids, str):
+ self.instance_ids = [self.instance_ids]
+ ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
+ self.log.info("Rebooting EC2 instances %s", ", ".join(self.instance_ids))
+ ec2_hook.conn.reboot_instances(InstanceIds=self.instance_ids)
+
+ if self.wait_for_completion:
+ ec2_hook.get_waiter("instance_running").wait(
+ InstanceIds=self.instance_ids,
+ WaiterConfig={
+ "Delay": self.poll_interval,
+ "MaxAttempts": self.max_attempts,
+ },
+ )
+
+
+class EC2HibernateInstanceOperator(BaseOperator):
+ """
+ Hibernate Amazon EC2 instances.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:EC2HibernateInstanceOperator`
+
+ :param instance_ids: ID of the instance(s) to be hibernated.
+ :param aws_conn_id: AWS connection to use
+ :param region_name: AWS region name associated with the client.
+ :param poll_interval: Number of seconds to wait before attempting to
+ check state of instance. Only used if wait_for_completion is True. Default is 20.
+ :param max_attempts: Maximum number of attempts when checking state of instance.
+ Only used if wait_for_completion is True. Default is 20.
+ :param wait_for_completion: If True, the operator will wait for the instance to be
+ in the `stopped` state before returning.
+ """
+
+ template_fields: Sequence[str] = ("instance_ids", "region_name")
+ ui_color = "#eeaa11"
+ ui_fgcolor = "#ffffff"
+
+ def __init__(
+ self,
+ *,
+ instance_ids: str | list[str],
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
+ poll_interval: int = 20,
+ max_attempts: int = 20,
+ wait_for_completion: bool = False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.instance_ids = instance_ids
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+ self.poll_interval = poll_interval
+ self.max_attempts = max_attempts
+ self.wait_for_completion = wait_for_completion
+
+ def execute(self, context: Context):
+ if isinstance(self.instance_ids, str):
+ self.instance_ids = [self.instance_ids]
+ ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
+ self.log.info("Hibernating EC2 instances %s", ", ".join(self.instance_ids))
+ instances = ec2_hook.get_instances(instance_ids=self.instance_ids)
+
+ for instance in instances:
+ hibernation_options = instance.get("HibernationOptions")
+ if not hibernation_options or not hibernation_options["Configured"]:
+ raise AirflowException(f"Instance {instance['InstanceId']} is not configured for hibernation")
+
+ ec2_hook.conn.stop_instances(InstanceIds=self.instance_ids, Hibernate=True)
+
+ if self.wait_for_completion:
+ ec2_hook.get_waiter("instance_stopped").wait(
+ InstanceIds=self.instance_ids,
+ WaiterConfig={
+ "Delay": self.poll_interval,
+ "MaxAttempts": self.max_attempts,
+ },
+ )
diff --git a/docs/apache-airflow-providers-amazon/operators/ec2.rst b/docs/apache-airflow-providers-amazon/operators/ec2.rst
index 2018d8113f..e5462b32a1 100644
--- a/docs/apache-airflow-providers-amazon/operators/ec2.rst
+++ b/docs/apache-airflow-providers-amazon/operators/ec2.rst
@@ -86,6 +86,34 @@ To terminate an Amazon EC2 instance you can use
:start-after: [START howto_operator_ec2_terminate_instance]
:end-before: [END howto_operator_ec2_terminate_instance]
+.. _howto/operator:EC2RebootInstanceOperator:
+
+Reboot an Amazon EC2 instance
+================================
+
+To reboot an Amazon EC2 instance you can use
+:class:`~airflow.providers.amazon.aws.operators.ec2.EC2RebootInstanceOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_ec2_reboot_instance]
+ :end-before: [END howto_operator_ec2_reboot_instance]
+
+.. _howto/operator:EC2HibernateInstanceOperator:
+
+Hibernate an Amazon EC2 instance
+================================
+
+To hibernate an Amazon EC2 instance you can use
+:class:`~airflow.providers.amazon.aws.operators.ec2.EC2HibernateInstanceOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_ec2_hibernate_instance]
+ :end-before: [END howto_operator_ec2_hibernate_instance]
+
Sensors
-------
diff --git a/tests/providers/amazon/aws/operators/test_ec2.py b/tests/providers/amazon/aws/operators/test_ec2.py
index adf3ffeb91..b11d72b714 100644
--- a/tests/providers/amazon/aws/operators/test_ec2.py
+++ b/tests/providers/amazon/aws/operators/test_ec2.py
@@ -17,11 +17,15 @@
# under the License.
from __future__ import annotations
+import pytest
from moto import mock_ec2
+from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
from airflow.providers.amazon.aws.operators.ec2 import (
EC2CreateInstanceOperator,
+ EC2HibernateInstanceOperator,
+ EC2RebootInstanceOperator,
EC2StartInstanceOperator,
EC2StopInstanceOperator,
EC2TerminateInstanceOperator,
@@ -205,3 +209,166 @@ class TestEC2StopInstanceOperator(BaseEc2TestClass):
stop_test.execute(None)
# assert instance state is stopped
assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped"
+
+
+class TestEC2HibernateInstanceOperator(BaseEc2TestClass):
+ def test_init(self):
+ ec2_operator = EC2HibernateInstanceOperator(
+ task_id="task_test",
+ instance_ids="i-123abc",
+ )
+ assert ec2_operator.task_id == "task_test"
+ assert ec2_operator.instance_ids == "i-123abc"
+
+ @mock_ec2
+ def test_hibernate_instance(self):
+ # create instance
+ ec2_hook = EC2Hook()
+ create_instance = EC2CreateInstanceOperator(
+ image_id=self._get_image_id(ec2_hook),
+ task_id="test_create_instance",
+ config={"HibernationOptions": {"Configured": True}},
+ )
+ instance_id = create_instance.execute(None)
+
+ # hibernate instance
+ hibernate_test = EC2HibernateInstanceOperator(
+ task_id="hibernate_test",
+ instance_ids=instance_id[0],
+ )
+ hibernate_test.execute(None)
+ # assert instance state is stopped
+ assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped"
+
+ @mock_ec2
+ def test_hibernate_multiple_instances(self):
+ ec2_hook = EC2Hook()
+ create_instances = EC2CreateInstanceOperator(
+ task_id="test_create_multiple_instances",
+ image_id=self._get_image_id(hook=ec2_hook),
+ config={"HibernationOptions": {"Configured": True}},
+ min_count=5,
+ max_count=5,
+ )
+ instance_ids = create_instances.execute(None)
+ assert len(instance_ids) == 5
+
+ for id in instance_ids:
+ assert ec2_hook.get_instance_state(instance_id=id) == "running"
+
+ hibernate_instance = EC2HibernateInstanceOperator(
+ task_id="test_hibernate_instance", instance_ids=instance_ids
+ )
+ hibernate_instance.execute(None)
+ for id in instance_ids:
+ assert ec2_hook.get_instance_state(instance_id=id) == "stopped"
+
+ @mock_ec2
+ def test_cannot_hibernate_instance(self):
+ # create instance
+ ec2_hook = EC2Hook()
+ create_instance = EC2CreateInstanceOperator(
+ image_id=self._get_image_id(ec2_hook),
+ task_id="test_create_instance",
+ )
+ instance_id = create_instance.execute(None)
+
+ # hibernate instance
+ hibernate_test = EC2HibernateInstanceOperator(
+ task_id="hibernate_test",
+ instance_ids=instance_id[0],
+ )
+
+ # assert hibernating an instance not configured for hibernation raises an error
+ with pytest.raises(
+ AirflowException,
+ match="Instance .* is not configured for hibernation",
+ ):
+ hibernate_test.execute(None)
+
+ # assert instance state is running
+ assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"
+
+ @mock_ec2
+ def test_cannot_hibernate_some_instances(self):
+ # create instance
+ ec2_hook = EC2Hook()
+ create_instance_hibernate = EC2CreateInstanceOperator(
+ image_id=self._get_image_id(ec2_hook),
+ task_id="test_create_instance",
+ config={"HibernationOptions": {"Configured": True}},
+ )
+ instance_id_hibernate = create_instance_hibernate.execute(None)
+ create_instance_cannot_hibernate = EC2CreateInstanceOperator(
+ image_id=self._get_image_id(ec2_hook),
+ task_id="test_create_instance",
+ )
+ instance_id_cannot_hibernate = create_instance_cannot_hibernate.execute(None)
+ instance_ids = [instance_id_hibernate[0], instance_id_cannot_hibernate[0]]
+
+ # hibernate instance
+ hibernate_test = EC2HibernateInstanceOperator(
+ task_id="hibernate_test",
+ instance_ids=instance_ids,
+ )
+ # assert hibernating an instance not configured for hibernation raises an error
+ with pytest.raises(
+ AirflowException,
+ match="Instance .* is not configured for hibernation",
+ ):
+ hibernate_test.execute(None)
+
+ # assert instance state is running
+ for id in instance_ids:
+ assert ec2_hook.get_instance_state(instance_id=id) == "running"
+
+
+class TestEC2RebootInstanceOperator(BaseEc2TestClass):
+ def test_init(self):
+ ec2_operator = EC2RebootInstanceOperator(
+ task_id="task_test",
+ instance_ids="i-123abc",
+ )
+ assert ec2_operator.task_id == "task_test"
+ assert ec2_operator.instance_ids == "i-123abc"
+
+ @mock_ec2
+ def test_reboot_instance(self):
+ # create instance
+ ec2_hook = EC2Hook()
+ create_instance = EC2CreateInstanceOperator(
+ image_id=self._get_image_id(ec2_hook),
+ task_id="test_create_instance",
+ )
+ instance_id = create_instance.execute(None)
+
+ # reboot instance
+ reboot_test = EC2RebootInstanceOperator(
+ task_id="reboot_test",
+ instance_ids=instance_id[0],
+ )
+ reboot_test.execute(None)
+ # assert instance state is running
+ assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"
+
+ @mock_ec2
+ def test_reboot_multiple_instances(self):
+ ec2_hook = EC2Hook()
+ create_instances = EC2CreateInstanceOperator(
+ task_id="test_create_multiple_instances",
+ image_id=self._get_image_id(hook=ec2_hook),
+ min_count=5,
+ max_count=5,
+ )
+ instance_ids = create_instances.execute(None)
+ assert len(instance_ids) == 5
+
+ for id in instance_ids:
+ assert ec2_hook.get_instance_state(instance_id=id) == "running"
+
+ reboot_instances = EC2RebootInstanceOperator(
+ task_id="test_reboot_instance", instance_ids=instance_ids
+ )
+ reboot_instances.execute(None)
+ for id in instance_ids:
+ assert ec2_hook.get_instance_state(instance_id=id) == "running"
diff --git a/tests/system/providers/amazon/aws/example_ec2.py b/tests/system/providers/amazon/aws/example_ec2.py
index 506d73908b..aeffe4ad34 100644
--- a/tests/system/providers/amazon/aws/example_ec2.py
+++ b/tests/system/providers/amazon/aws/example_ec2.py
@@ -26,6 +26,8 @@ from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.operators.ec2 import (
EC2CreateInstanceOperator,
+ EC2HibernateInstanceOperator,
+ EC2RebootInstanceOperator,
EC2StartInstanceOperator,
EC2StopInstanceOperator,
EC2TerminateInstanceOperator,
@@ -103,6 +105,7 @@ with DAG(
# Use IMDSv2 for greater security, see the following doc for more details:
# https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
"MetadataOptions": {"HttpEndpoint": "enabled", "HttpTokens": "required"},
+ "HibernationOptions": {"Configured": True},
}
# EC2CreateInstanceOperator creates and starts the EC2 instances. To test the EC2StartInstanceOperator,
@@ -142,6 +145,20 @@ with DAG(
)
# [END howto_sensor_ec2_instance_state]
+ # [START howto_operator_ec2_reboot_instance]
+ reboot_instance = EC2RebootInstanceOperator(
+ task_id="reboot_instance",
+ instance_ids=instance_id,
+ )
+ # [END howto_operator_ec2_reboot_instance]
+
+ # [START howto_operator_ec2_hibernate_instance]
+ hibernate_instance = EC2HibernateInstanceOperator(
+ task_id="hibernate_instance",
+ instance_ids=instance_id,
+ )
+ # [END howto_operator_ec2_hibernate_instance]
+
# [START howto_operator_ec2_terminate_instance]
terminate_instance = EC2TerminateInstanceOperator(
task_id="terminate_instance",
@@ -161,6 +178,8 @@ with DAG(
stop_instance,
start_instance,
await_instance,
+ reboot_instance,
+ hibernate_instance,
terminate_instance,
# TEST TEARDOWN
delete_key_pair(key_name),