This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ca1202fd31 Add `EC2HibernateInstanceOperator` and 
`EC2RebootInstanceOperator` (#35790)
ca1202fd31 is described below

commit ca1202fd31f0ea8c25833cf11a5f7aa97c1db87b
Author: David <[email protected]>
AuthorDate: Thu Nov 23 12:58:59 2023 -0500

    Add `EC2HibernateInstanceOperator` and `EC2RebootInstanceOperator` (#35790)
---
 airflow/providers/amazon/aws/operators/ec2.py      | 126 ++++++++++++++++
 .../operators/ec2.rst                              |  28 ++++
 tests/providers/amazon/aws/operators/test_ec2.py   | 167 +++++++++++++++++++++
 tests/system/providers/amazon/aws/example_ec2.py   |  19 +++
 4 files changed, 340 insertions(+)

diff --git a/airflow/providers/amazon/aws/operators/ec2.py 
b/airflow/providers/amazon/aws/operators/ec2.py
index b9de533378..2dbb6986d7 100644
--- a/airflow/providers/amazon/aws/operators/ec2.py
+++ b/airflow/providers/amazon/aws/operators/ec2.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
 
+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
 
@@ -254,3 +255,128 @@ class EC2TerminateInstanceOperator(BaseOperator):
                         "MaxAttempts": self.max_attempts,
                     },
                 )
+
+
class EC2RebootInstanceOperator(BaseOperator):
    """
    Reboot one or more Amazon EC2 instances.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:EC2RebootInstanceOperator`

    :param instance_ids: ID, or list of IDs, of the instance(s) to reboot. A single
        string is treated as a one-element list.
    :param aws_conn_id: AWS connection to use
    :param region_name: AWS region name associated with the client.
    :param poll_interval: Seconds to wait between instance-state checks.
        Only used if wait_for_completion is True. Default is 20.
    :param max_attempts: Maximum number of instance-state checks.
        Only used if wait_for_completion is True. Default is 20.
    :param wait_for_completion: If True, block until the instance(s) report the
        `running` state (EC2 ``instance_running`` waiter) before returning.
    """

    template_fields: Sequence[str] = ("instance_ids", "region_name")
    ui_color = "#eeaa11"
    ui_fgcolor = "#ffffff"

    def __init__(
        self,
        *,
        instance_ids: str | list[str],
        aws_conn_id: str = "aws_default",
        region_name: str | None = None,
        poll_interval: int = 20,
        max_attempts: int = 20,
        wait_for_completion: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Targets and connection settings.
        self.instance_ids = instance_ids
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        # Waiter settings (only consulted when wait_for_completion is True).
        self.poll_interval = poll_interval
        self.max_attempts = max_attempts
        self.wait_for_completion = wait_for_completion

    def execute(self, context: Context):
        # A bare string is accepted for convenience; the EC2 API expects a list.
        if isinstance(self.instance_ids, str):
            self.instance_ids = [self.instance_ids]
        targets = self.instance_ids

        hook = EC2Hook(
            aws_conn_id=self.aws_conn_id,
            region_name=self.region_name,
            api_type="client_type",
        )
        self.log.info("Rebooting EC2 instances %s", ", ".join(targets))
        hook.conn.reboot_instances(InstanceIds=targets)

        if not self.wait_for_completion:
            return

        # NOTE(review): a rebooting EC2 instance normally stays in the `running`
        # state, so this waiter may succeed almost immediately — confirm.
        waiter_config = {"Delay": self.poll_interval, "MaxAttempts": self.max_attempts}
        hook.get_waiter("instance_running").wait(InstanceIds=targets, WaiterConfig=waiter_config)
+
+
class EC2HibernateInstanceOperator(BaseOperator):
    """
    Hibernate Amazon EC2 instances.

    Every requested instance is first checked for ``HibernationOptions.Configured``.
    If any instance is not configured for hibernation, an ``AirflowException`` is
    raised and no instance is stopped. Otherwise the instances are stopped with
    ``Hibernate=True``.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:EC2HibernateInstanceOperator`

    :param instance_ids: ID of the instance(s) to be hibernated. A single ID string
        is treated as a one-element list.
    :param aws_conn_id: AWS connection to use
    :param region_name: AWS region name associated with the client.
    :param poll_interval: Number of seconds to wait before attempting to
        check state of instance. Only used if wait_for_completion is True. Default is 20.
    :param max_attempts: Maximum number of attempts when checking state of instance.
        Only used if wait_for_completion is True. Default is 20.
    :param wait_for_completion: If True, the operator will wait for the instance to be
        in the `stopped` state before returning.
    :raises AirflowException: If any of the instances is not configured for hibernation.
    """

    template_fields: Sequence[str] = ("instance_ids", "region_name")
    ui_color = "#eeaa11"
    ui_fgcolor = "#ffffff"

    def __init__(
        self,
        *,
        instance_ids: str | list[str],
        aws_conn_id: str = "aws_default",
        region_name: str | None = None,
        poll_interval: int = 20,
        max_attempts: int = 20,
        wait_for_completion: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.instance_ids = instance_ids
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.poll_interval = poll_interval
        self.max_attempts = max_attempts
        self.wait_for_completion = wait_for_completion

    def execute(self, context: Context):
        # A bare string is accepted for convenience; the EC2 API expects a list.
        if isinstance(self.instance_ids, str):
            self.instance_ids = [self.instance_ids]
        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")

        # Validate the whole batch before stopping anything, so a request that
        # contains an unconfigured instance leaves every instance untouched.
        for instance in ec2_hook.get_instances(instance_ids=self.instance_ids):
            # `or {}` / `.get` so that a missing "HibernationOptions" or "Configured"
            # key is reported as "not configured" rather than as a bare KeyError.
            hibernation_options = instance.get("HibernationOptions") or {}
            if not hibernation_options.get("Configured", False):
                raise AirflowException(
                    f"Instance {instance['InstanceId']} is not configured for hibernation"
                )

        # Log only after validation passed, so the log never claims a hibernation
        # that was rejected.
        self.log.info("Hibernating EC2 instances %s", ", ".join(self.instance_ids))
        ec2_hook.conn.stop_instances(InstanceIds=self.instance_ids, Hibernate=True)

        if self.wait_for_completion:
            ec2_hook.get_waiter("instance_stopped").wait(
                InstanceIds=self.instance_ids,
                WaiterConfig={
                    "Delay": self.poll_interval,
                    "MaxAttempts": self.max_attempts,
                },
            )
diff --git a/docs/apache-airflow-providers-amazon/operators/ec2.rst 
b/docs/apache-airflow-providers-amazon/operators/ec2.rst
index 2018d8113f..e5462b32a1 100644
--- a/docs/apache-airflow-providers-amazon/operators/ec2.rst
+++ b/docs/apache-airflow-providers-amazon/operators/ec2.rst
@@ -86,6 +86,34 @@ To terminate an Amazon EC2 instance you can use
     :start-after: [START howto_operator_ec2_terminate_instance]
     :end-before: [END howto_operator_ec2_terminate_instance]
 
+.. _howto/operator:EC2RebootInstanceOperator:
+
+Reboot an Amazon EC2 instance
+=============================
+
+To reboot an Amazon EC2 instance you can use
+:class:`~airflow.providers.amazon.aws.operators.ec2.EC2RebootInstanceOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_ec2_reboot_instance]
+    :end-before: [END howto_operator_ec2_reboot_instance]
+
+.. _howto/operator:EC2HibernateInstanceOperator:
+
+Hibernate an Amazon EC2 instance
+================================
+
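+To hibernate an Amazon EC2 instance you can use
+:class:`~airflow.providers.amazon.aws.operators.ec2.EC2HibernateInstanceOperator`.
+Every targeted instance must have been launched with hibernation configured
+(``HibernationOptions={"Configured": True}``); otherwise the operator raises an error
+and no instance is stopped.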
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_ec2_hibernate_instance]
+    :end-before: [END howto_operator_ec2_hibernate_instance]
+
 Sensors
 -------
 
diff --git a/tests/providers/amazon/aws/operators/test_ec2.py 
b/tests/providers/amazon/aws/operators/test_ec2.py
index adf3ffeb91..b11d72b714 100644
--- a/tests/providers/amazon/aws/operators/test_ec2.py
+++ b/tests/providers/amazon/aws/operators/test_ec2.py
@@ -17,11 +17,15 @@
 # under the License.
 from __future__ import annotations
 
+import pytest
 from moto import mock_ec2
 
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
 from airflow.providers.amazon.aws.operators.ec2 import (
     EC2CreateInstanceOperator,
+    EC2HibernateInstanceOperator,
+    EC2RebootInstanceOperator,
     EC2StartInstanceOperator,
     EC2StopInstanceOperator,
     EC2TerminateInstanceOperator,
@@ -205,3 +209,166 @@ class TestEC2StopInstanceOperator(BaseEc2TestClass):
         stop_test.execute(None)
         # assert instance state is running
         assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == 
"stopped"
+
+
class TestEC2HibernateInstanceOperator(BaseEc2TestClass):
    """Tests for ``EC2HibernateInstanceOperator`` against a moto-mocked EC2 API."""

    def test_init(self):
        ec2_operator = EC2HibernateInstanceOperator(
            task_id="task_test",
            instance_ids="i-123abc",
        )
        assert ec2_operator.task_id == "task_test"
        # A single ID is stored as given; conversion to a list happens in execute().
        assert ec2_operator.instance_ids == "i-123abc"

    @mock_ec2
    def test_hibernate_instance(self):
        # Create a single instance with hibernation enabled.
        ec2_hook = EC2Hook()
        create_instance = EC2CreateInstanceOperator(
            image_id=self._get_image_id(ec2_hook),
            task_id="test_create_instance",
            config={"HibernationOptions": {"Configured": True}},
        )
        instance_ids = create_instance.execute(None)

        hibernate_test = EC2HibernateInstanceOperator(
            task_id="hibernate_test",
            instance_ids=instance_ids[0],
        )
        hibernate_test.execute(None)
        # Hibernation is a stop with Hibernate=True, so the instance ends up stopped.
        assert ec2_hook.get_instance_state(instance_id=instance_ids[0]) == "stopped"

    @mock_ec2
    def test_hibernate_multiple_instances(self):
        ec2_hook = EC2Hook()
        create_instances = EC2CreateInstanceOperator(
            task_id="test_create_multiple_instances",
            image_id=self._get_image_id(hook=ec2_hook),
            config={"HibernationOptions": {"Configured": True}},
            min_count=5,
            max_count=5,
        )
        instance_ids = create_instances.execute(None)
        assert len(instance_ids) == 5

        for instance_id in instance_ids:
            assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"

        hibernate_instance = EC2HibernateInstanceOperator(
            task_id="test_hibernate_instance", instance_ids=instance_ids
        )
        hibernate_instance.execute(None)
        for instance_id in instance_ids:
            assert ec2_hook.get_instance_state(instance_id=instance_id) == "stopped"

    @mock_ec2
    def test_cannot_hibernate_instance(self):
        # Create an instance WITHOUT hibernation configured.
        ec2_hook = EC2Hook()
        create_instance = EC2CreateInstanceOperator(
            image_id=self._get_image_id(ec2_hook),
            task_id="test_create_instance",
        )
        instance_ids = create_instance.execute(None)

        hibernate_test = EC2HibernateInstanceOperator(
            task_id="hibernate_test",
            instance_ids=instance_ids[0],
        )

        # Hibernating an instance not configured for hibernation must raise.
        with pytest.raises(
            AirflowException,
            match="Instance .* is not configured for hibernation",
        ):
            hibernate_test.execute(None)

        # The rejected instance must not have been stopped.
        assert ec2_hook.get_instance_state(instance_id=instance_ids[0]) == "running"

    @mock_ec2
    def test_cannot_hibernate_some_instances(self):
        # Create one instance with hibernation configured and one without.
        ec2_hook = EC2Hook()
        create_instance_hibernate = EC2CreateInstanceOperator(
            image_id=self._get_image_id(ec2_hook),
            task_id="test_create_instance_hibernate",
            config={"HibernationOptions": {"Configured": True}},
        )
        instance_ids_hibernate = create_instance_hibernate.execute(None)
        create_instance_cannot_hibernate = EC2CreateInstanceOperator(
            image_id=self._get_image_id(ec2_hook),
            task_id="test_create_instance_no_hibernate",
        )
        instance_ids_cannot_hibernate = create_instance_cannot_hibernate.execute(None)
        instance_ids = [instance_ids_hibernate[0], instance_ids_cannot_hibernate[0]]

        hibernate_test = EC2HibernateInstanceOperator(
            task_id="hibernate_test",
            instance_ids=instance_ids,
        )
        # A batch containing any unconfigured instance must raise.
        with pytest.raises(
            AirflowException,
            match="Instance .* is not configured for hibernation",
        ):
            hibernate_test.execute(None)

        # Validation happens before any stop call, so both instances stay running.
        for instance_id in instance_ids:
            assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"
+
+
class TestEC2RebootInstanceOperator(BaseEc2TestClass):
    """Tests for ``EC2RebootInstanceOperator`` against a moto-mocked EC2 API."""

    def test_init(self):
        ec2_operator = EC2RebootInstanceOperator(
            task_id="task_test",
            instance_ids="i-123abc",
        )
        assert ec2_operator.task_id == "task_test"
        # A single ID is stored as given; conversion to a list happens in execute().
        assert ec2_operator.instance_ids == "i-123abc"

    @mock_ec2
    def test_reboot_instance(self):
        # Create a single instance to reboot.
        ec2_hook = EC2Hook()
        create_instance = EC2CreateInstanceOperator(
            image_id=self._get_image_id(ec2_hook),
            task_id="test_create_instance",
        )
        instance_ids = create_instance.execute(None)

        reboot_test = EC2RebootInstanceOperator(
            task_id="reboot_test",
            instance_ids=instance_ids[0],
        )
        reboot_test.execute(None)
        # After a reboot the instance is expected to be running again.
        assert ec2_hook.get_instance_state(instance_id=instance_ids[0]) == "running"

    @mock_ec2
    def test_reboot_multiple_instances(self):
        ec2_hook = EC2Hook()
        create_instances = EC2CreateInstanceOperator(
            task_id="test_create_multiple_instances",
            image_id=self._get_image_id(hook=ec2_hook),
            min_count=5,
            max_count=5,
        )
        instance_ids = create_instances.execute(None)
        assert len(instance_ids) == 5

        for instance_id in instance_ids:
            assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"

        reboot_instance = EC2RebootInstanceOperator(
            task_id="test_reboot_instance", instance_ids=instance_ids
        )
        reboot_instance.execute(None)
        for instance_id in instance_ids:
            assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"
diff --git a/tests/system/providers/amazon/aws/example_ec2.py 
b/tests/system/providers/amazon/aws/example_ec2.py
index 506d73908b..aeffe4ad34 100644
--- a/tests/system/providers/amazon/aws/example_ec2.py
+++ b/tests/system/providers/amazon/aws/example_ec2.py
@@ -26,6 +26,8 @@ from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
 from airflow.providers.amazon.aws.operators.ec2 import (
     EC2CreateInstanceOperator,
+    EC2HibernateInstanceOperator,
+    EC2RebootInstanceOperator,
     EC2StartInstanceOperator,
     EC2StopInstanceOperator,
     EC2TerminateInstanceOperator,
@@ -103,6 +105,7 @@ with DAG(
         # Use IMDSv2 for greater security, see the following doc for more 
details:
         # 
https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
         "MetadataOptions": {"HttpEndpoint": "enabled", "HttpTokens": 
"required"},
+        "HibernationOptions": {"Configured": True},
     }
 
     # EC2CreateInstanceOperator creates and starts the EC2 instances. To test 
the EC2StartInstanceOperator,
@@ -142,6 +145,20 @@ with DAG(
     )
     # [END howto_sensor_ec2_instance_state]
 
+    # [START howto_operator_ec2_reboot_instance]
+    reboot_instance = EC2RebootInstanceOperator(
+        task_id="reboot_instace",
+        instance_ids=instance_id,
+    )
+    # [END howto_operator_ec2_reboot_instance]
+
+    # [START howto_operator_ec2_hibernate_instance]
+    hibernate_instance = EC2HibernateInstanceOperator(
+        task_id="hibernate_instace",
+        instance_ids=instance_id,
+    )
+    # [END howto_operator_ec2_hibernate_instance]
+
     # [START howto_operator_ec2_terminate_instance]
     terminate_instance = EC2TerminateInstanceOperator(
         task_id="terminate_instance",
@@ -161,6 +178,8 @@ with DAG(
         stop_instance,
         start_instance,
         await_instance,
+        reboot_instance,
+        hibernate_instance,
         terminate_instance,
         # TEST TEARDOWN
         delete_key_pair(key_name),

Reply via email to