dashton90 commented on code in PR #35770:
URL: https://github.com/apache/airflow/pull/35770#discussion_r1401385011
##########
airflow/providers/amazon/aws/operators/ec2.py:
##########
@@ -254,3 +256,100 @@ def execute(self, context: Context):
"MaxAttempts": self.max_attempts,
},
)
+
+class EC2RebootInstanceOperator(BaseOperator):
+ """
+ Reboot AWS EC2 instance using boto3.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:EC2RebootInstanceOperator`
+
+ :param instance_id: id of the AWS EC2 instance
+ :param aws_conn_id: aws connection to use
+ :param region_name: (optional) aws region name associated with the client
+ :param check_interval: time in seconds that the job should wait in
+ between each instance state checks until operation is completed
+ """
+
+ template_fields: Sequence[str] = ("instance_id", "region_name")
+ ui_color = "#eeaa11"
+ ui_fgcolor = "#ffffff"
+
+ def __init__(
+ self,
+ *,
+ instance_id: str,
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
+ check_interval: float = 15,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.instance_id = instance_id
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+ self.check_interval = check_interval
+
+ def execute(self, context: Context):
+ ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
+ self.log.info("Rebooting EC2 instance %s", self.instance_id)
+ instance = ec2_hook.get_instance(instance_id=self.instance_id)
+ instance.reboot()
+ ec2_hook.wait_for_state(
+ instance_id=self.instance_id,
+ target_state="running",
+ check_interval=self.check_interval,
+ )
+
+class EC2HibernateInstanceOperator(BaseOperator):
+ """
+ Hibernate AWS EC2 instance using boto3.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:EC2HibernateInstanceOperator`
+
+ :param instance_id: id of the AWS EC2 instance
+ :param aws_conn_id: aws connection to use
+ :param region_name: (optional) aws region name associated with the client
+ :param check_interval: time in seconds that the job should wait in
+ between each instance state checks until operation is completed
+ """
+
+ template_fields: Sequence[str] = ("instance_id", "region_name")
+ ui_color = "#eeaa11"
+ ui_fgcolor = "#ffffff"
+
+ def __init__(
+ self,
+ *,
+ instance_id: str,
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
+ check_interval: float = 15,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.instance_id = instance_id
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+ self.check_interval = check_interval
+
+ def execute(self, context: Context):
+ ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
+ self.log.info("Hibernating EC2 instance %s", self.instance_id)
+ instance = ec2_hook.get_instance(instance_id=self.instance_id)
+
+ hibernation_options = instance.hibernation_options
+ if not hibernation_options or not hibernation_options["Configured"]:
+ raise EC2HibernationError(f"Instance {self.instance_id} is not
configured for hibernation")
Review Comment:
Not at all opposed to a generic exception. Replaced it with `AirflowException`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
users@infra.apache.org