o-nikolas commented on code in PR #35770:
URL: https://github.com/apache/airflow/pull/35770#discussion_r1400955844


##########
airflow/providers/amazon/aws/operators/ec2.py:
##########
@@ -254,3 +256,100 @@ def execute(self, context: Context):
                         "MaxAttempts": self.max_attempts,
                     },
                 )
+
+class EC2RebootInstanceOperator(BaseOperator):
+    """
+    Reboot AWS EC2 instance using boto3.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:EC2RebootInstanceOperator`
+
+    :param instance_id: id of the AWS EC2 instance
+    :param aws_conn_id: aws connection to use
+    :param region_name: (optional) aws region name associated with the client

Review Comment:
   Why mark just this one as optional but not the args above and below it?
   I'd either do it for all optional args or not bother. The same applies to
   the Hibernate operator below.



##########
tests/providers/amazon/aws/operators/test_ec2.py:
##########
@@ -205,3 +210,102 @@ def test_stop_instance(self):
         stop_test.execute(None)
         # assert instance state is running
         assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == 
"stopped"
+
+class TestEC2HibernateInstanceOperator(BaseEc2TestClass):
+    def test_init(self):
+        ec2_operator = EC2HibernateInstanceOperator(
+            task_id="task_test",
+            instance_id="i-123abc",
+            aws_conn_id="aws_conn_test",
+            region_name="region-test",
+            check_interval=3,
+        )
+        assert ec2_operator.task_id == "task_test"
+        assert ec2_operator.instance_id == "i-123abc"
+        assert ec2_operator.aws_conn_id == "aws_conn_test"
+        assert ec2_operator.region_name == "region-test"
+        assert ec2_operator.check_interval == 3
+
+    @mock_ec2
+    def test_hibernate_instance(self):
+        # create instance
+        ec2_hook = EC2Hook()
+        create_instance = EC2CreateInstanceOperator(
+            image_id=self._get_image_id(ec2_hook),
+            task_id="test_create_instance",
+            config={
+                "HibernationOptions": {
+                    "Configured": True
+                }
+            },
+        )
+        instance_id = create_instance.execute(None)
+
+        # hibernate instance
+        hibernate_test = EC2HibernateInstanceOperator(
+            task_id="hibernate_test",
+            instance_id=instance_id[0],
+        )
+        hibernate_test.execute(None)
+        # assert instance state is running

Review Comment:
   ```suggestion
           # assert instance state is stopped
   ```



##########
airflow/providers/amazon/aws/operators/ec2.py:
##########
@@ -254,3 +256,100 @@ def execute(self, context: Context):
                         "MaxAttempts": self.max_attempts,
                     },
                 )
+
+class EC2RebootInstanceOperator(BaseOperator):
+    """
+    Reboot AWS EC2 instance using boto3.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:EC2RebootInstanceOperator`
+
+    :param instance_id: id of the AWS EC2 instance
+    :param aws_conn_id: aws connection to use
+    :param region_name: (optional) aws region name associated with the client
+    :param check_interval: time in seconds that the job should wait in
+        between each instance state checks until operation is completed
+    """
+
+    template_fields: Sequence[str] = ("instance_id", "region_name")
+    ui_color = "#eeaa11"
+    ui_fgcolor = "#ffffff"
+
+    def __init__(
+        self,
+        *,
+        instance_id: str,
+        aws_conn_id: str = "aws_default",
+        region_name: str | None = None,
+        check_interval: float = 15,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.instance_id = instance_id
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.check_interval = check_interval
+
+    def execute(self, context: Context):
+        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
+        self.log.info("Rebooting EC2 instance %s", self.instance_id)
+        instance = ec2_hook.get_instance(instance_id=self.instance_id)
+        instance.reboot()
+        ec2_hook.wait_for_state(
+            instance_id=self.instance_id,
+            target_state="running",
+            check_interval=self.check_interval,
+        )
+
+class EC2HibernateInstanceOperator(BaseOperator):
+    """
+    Hibernate AWS EC2 instance using boto3.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:EC2HibernateInstanceOperator`
+
+    :param instance_id: id of the AWS EC2 instance
+    :param aws_conn_id: aws connection to use
+    :param region_name: (optional) aws region name associated with the client
+    :param check_interval: time in seconds that the job should wait in
+        between each instance state checks until operation is completed
+    """
+
+    template_fields: Sequence[str] = ("instance_id", "region_name")
+    ui_color = "#eeaa11"
+    ui_fgcolor = "#ffffff"
+
+    def __init__(
+        self,
+        *,
+        instance_id: str,
+        aws_conn_id: str = "aws_default",
+        region_name: str | None = None,
+        check_interval: float = 15,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.instance_id = instance_id
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.check_interval = check_interval
+
+    def execute(self, context: Context):
+        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
+        self.log.info("Hibernating EC2 instance %s", self.instance_id)
+        instance = ec2_hook.get_instance(instance_id=self.instance_id)
+        
+        hibernation_options = instance.hibernation_options
+        if not hibernation_options or not hibernation_options["Configured"]:
+            raise EC2HibernationError(f"Instance {self.instance_id} is not 
configured for hibernation")

Review Comment:
   Agreed, I don't see the value in a custom exception here.



##########
tests/system/providers/amazon/aws/example_ec2.py:
##########
@@ -150,6 +153,22 @@ def parse_response(instance_ids: list):
     )
     # [END howto_operator_ec2_terminate_instance]
     terminate_instance.trigger_rule = TriggerRule.ALL_DONE
+
+    # [START howto_operator_ec2_hibernate_instance]
+    hibernate_instance = EC2HibernateInstanceOperator(
+        task_id="hibernate_instace",
+        instance_id=instance_id,
+    )
+    # [END howto_operator_ec2_hibernate_instance]
+    hibernate_instance.trigger_rule = TriggerRule.ALL_DONE
+    
+    # [START howto_operator_ec2_reboot_instance]
+    reboot_instance = EC2RebootInstanceOperator(
+        task_id="reboot_instace",
+        instance_id=instance_id,
+    )
+    # [END howto_operator_ec2_reboot_instance]
+    reboot_instance.trigger_rule = TriggerRule.ALL_DONE

Review Comment:
   The execution order is actually dictated by the `chain` below, where the
   hibernate/reboot tasks run before terminate. Still, I would reorder the task
   definitions here to match the order of the chain, for readability.



##########
airflow/providers/amazon/aws/operators/ec2.py:
##########
@@ -254,3 +256,100 @@ def execute(self, context: Context):
                         "MaxAttempts": self.max_attempts,
                     },
                 )
+
+class EC2RebootInstanceOperator(BaseOperator):
+    """
+    Reboot AWS EC2 instance using boto3.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:EC2RebootInstanceOperator`
+
+    :param instance_id: id of the AWS EC2 instance
+    :param aws_conn_id: aws connection to use
+    :param region_name: (optional) aws region name associated with the client
+    :param check_interval: time in seconds that the job should wait in
+        between each instance state checks until operation is completed
+    """
+
+    template_fields: Sequence[str] = ("instance_id", "region_name")
+    ui_color = "#eeaa11"
+    ui_fgcolor = "#ffffff"
+
+    def __init__(
+        self,
+        *,
+        instance_id: str,
+        aws_conn_id: str = "aws_default",
+        region_name: str | None = None,
+        check_interval: float = 15,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.instance_id = instance_id
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.check_interval = check_interval
+
+    def execute(self, context: Context):
+        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, 
region_name=self.region_name)
+        self.log.info("Rebooting EC2 instance %s", self.instance_id)
+        instance = ec2_hook.get_instance(instance_id=self.instance_id)
+        instance.reboot()
+        ec2_hook.wait_for_state(

Review Comment:
   We usually gate waits like this behind a `wait_for_completion` argument for
   users who prefer to fire and forget. Would you mind adding that here? If you
   grep for that argument name you'll find plenty of examples. The same applies
   to the Hibernate operator below.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to