ashb commented on a change in pull request #7731:
URL: https://github.com/apache/airflow/pull/7731#discussion_r412012777



##########
File path: airflow/providers/amazon/aws/operators/ec2_start_instance.py
##########
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import Optional
+
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
+from airflow.utils.decorators import apply_defaults
+
+
+class EC2StartInstanceOperator(BaseOperator):
+    """
+    Start AWS EC2 instance using boto3.
+
+    :param instance_id: id of the AWS EC2 instance
+    :type instance_id: str
+    :param aws_conn_id: aws connection to use
+    :type aws_conn_id: str
+    :param region_name: (optional) aws region name associated with the client
+    :type region_name: Optional[str]
+    :param check_interval: time in seconds that the job should wait in
+        between each instance state checks until operation is completed
+    :type check_interval: float
+    """
+
+    template_fields = ["region_name"]

Review comment:
       ```suggestion
       template_fields = ("instance_id", "region_name")
   ```

##########
File path: airflow/providers/amazon/aws/operators/ec2_stop_instance.py
##########
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import Optional
+
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
+from airflow.utils.decorators import apply_defaults
+
+
+class EC2StopInstanceOperator(BaseOperator):
+    """
+    Stop AWS EC2 instance using boto3.
+
+    :param instance_id: id of the AWS EC2 instance
+    :type instance_id: str
+    :param aws_conn_id: aws connection to use
+    :type aws_conn_id: str
+    :param region_name: (optional) aws region name associated with the client
+    :type region_name: Optional[str]
+    :param check_interval: time in seconds that the job should wait in
+        between each instance state checks until operation is completed
+    :type check_interval: float
+    """
+
+    template_fields = ["region_name"]

Review comment:
       ```suggestion
       template_fields = ("instance_id", "region_name")
   ```

##########
File path: airflow/providers/amazon/aws/sensors/ec2_instance_state.py
##########
@@ -0,0 +1,70 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import Optional
+
+from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
+from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class EC2InstanceStateSensor(BaseSensorOperator):
+    """
+    Check the state of the AWS EC2 instance until
+    state of the instance become equal to the target state.
+
+    :param target_state: target state of instance
+    :type target_state: str
+    :param instance_id: id of the AWS EC2 instance
+    :type instance_id: str
+    :param region_name: (optional) aws region name associated with the client
+    :type region_name: Optional[str]
+    """
+
+    template_fields = ["target_state", "region_name"]

Review comment:
       ```suggestion
       template_fields = ("instance_id", "target_state", "region_name")
   ```

##########
File path: tests/providers/amazon/aws/operators/test_ec2_stop_instance.py
##########
@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import unittest
+
+from moto import mock_ec2
+
+from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
+from airflow.providers.amazon.aws.operators.ec2_stop_instance import EC2StopInstanceOperator
+
+
+class TestEC2Operator(unittest.TestCase):
+
+    def test_init(self):
+        ec2_operator = EC2StopInstanceOperator(
+            task_id="task_test",
+            instance_id="i-123abc",
+            aws_conn_id="aws_conn_test",
+            region_name="region-test",
+            check_interval=3,
+        )
+        self.assertEqual(ec2_operator.task_id, "task_test")
+        self.assertEqual(ec2_operator.instance_id, "i-123abc")
+        self.assertEqual(ec2_operator.aws_conn_id, "aws_conn_test")
+        self.assertEqual(ec2_operator.region_name, "region-test")
+        self.assertEqual(ec2_operator.check_interval, 3)
+
+    @mock_ec2
+    def test_stop_instance(self):
+        # create instance
+        ec2_hook = EC2Hook()
+        instances = ec2_hook.conn.create_instances(
+            MaxCount=1,
+            MinCount=1,
+        )
+        instance_id = instances[0].instance_id
+
+        # stop instance
+        stop_test = EC2StopInstanceOperator(
+            task_id="stop_test",
+            instance_id=instance_id,
+        )
+        stop_test.execute(None)
+        # assert instance state is running
+        self.assertEqual(
+            ec2_hook.get_instance_state(instance_id=instance_id),
+            "stopped"
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()

Review comment:
       ```suggestion
   ```

##########
File path: tests/providers/amazon/aws/sensors/test_ec2_instance_state.py
##########
@@ -0,0 +1,132 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import unittest
+
+from moto import mock_ec2
+
+from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
+from airflow.providers.amazon.aws.sensors.ec2_instance_state import EC2InstanceStateSensor
+
+
+class TestEC2InstanceStateSensor(unittest.TestCase):
+
+    def test_init(self):
+        ec2_operator = EC2InstanceStateSensor(
+            task_id="task_test",
+            target_state="stopped",
+            instance_id="i-123abc",
+            aws_conn_id="aws_conn_test",
+            region_name="region-test",
+        )
+        self.assertEqual(ec2_operator.task_id, "task_test")
+        self.assertEqual(ec2_operator.target_state, "stopped")
+        self.assertEqual(ec2_operator.instance_id, "i-123abc")
+        self.assertEqual(ec2_operator.aws_conn_id, "aws_conn_test")
+        self.assertEqual(ec2_operator.region_name, "region-test")
+
+    def test_init_invalid_target_state(self):
+        invalid_target_state = "target_state_test"
+        with self.assertRaises(ValueError) as cm:
+            EC2InstanceStateSensor(
+                task_id="task_test",
+                target_state=invalid_target_state,
+                instance_id="i-123abc",
+            )
+        msg = f"Invalid target_state: {invalid_target_state}"
+        self.assertEqual(str(cm.exception), msg)
+
+    @mock_ec2
+    def test_running(self):
+        # create instance
+        ec2_hook = EC2Hook()
+        instances = ec2_hook.conn.create_instances(
+            MaxCount=1,
+            MinCount=1,
+        )
+        instance_id = instances[0].instance_id
+        # stop instance
+        ec2_hook.get_instance(instance_id=instance_id).stop()
+
+        # start sensor, waits until ec2 instance state became running
+        start_sensor = EC2InstanceStateSensor(
+            task_id="start_sensor",
+            target_state="running",
+            instance_id=instance_id,
+        )
+        # assert instance state is not running
+        self.assertFalse(start_sensor.poke(None))
+        # start instance
+        ec2_hook.get_instance(instance_id=instance_id).start()
+        # assert instance state is running
+        self.assertTrue(start_sensor.poke(None))
+
+    @mock_ec2
+    def test_stopped(self):
+        # create instance
+        ec2_hook = EC2Hook()
+        instances = ec2_hook.conn.create_instances(
+            MaxCount=1,
+            MinCount=1,
+        )
+        instance_id = instances[0].instance_id
+        # start instance
+        ec2_hook.get_instance(instance_id=instance_id).start()
+
+        # stop sensor, waits until ec2 instance state became stopped
+        stop_sensor = EC2InstanceStateSensor(
+            task_id="stop_sensor",
+            target_state="stopped",
+            instance_id=instance_id,
+        )
+        # assert instance state is not stopped
+        self.assertFalse(stop_sensor.poke(None))
+        # stop instance
+        ec2_hook.get_instance(instance_id=instance_id).stop()
+        # assert instance state is stopped
+        self.assertTrue(stop_sensor.poke(None))
+
+    @mock_ec2
+    def test_terminated(self):
+        # create instance
+        ec2_hook = EC2Hook()
+        instances = ec2_hook.conn.create_instances(
+            MaxCount=1,
+            MinCount=1,
+        )
+        instance_id = instances[0].instance_id
+        # start instance
+        ec2_hook.get_instance(instance_id=instance_id).start()
+
+        # stop sensor, waits until ec2 instance state became terminated
+        stop_sensor = EC2InstanceStateSensor(
+            task_id="stop_sensor",
+            target_state="terminated",
+            instance_id=instance_id,
+        )
+        # assert instance state is not terminated
+        self.assertFalse(stop_sensor.poke(None))
+        # stop instance
+        ec2_hook.get_instance(instance_id=instance_id).terminate()
+        # assert instance state is terminated
+        self.assertTrue(stop_sensor.poke(None))
+
+
+if __name__ == '__main__':
+    unittest.main()

Review comment:
       ```suggestion
   ```

##########
File path: tests/providers/amazon/aws/hooks/test_ec2.py
##########
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import unittest
+
+from moto import mock_ec2
+
+from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
+
+
+class TestEC2Hook(unittest.TestCase):
+
+    def test_init(self):
+        ec2_hook = EC2Hook(
+            aws_conn_id="aws_conn_test",
+            region_name="region-test",
+        )
+        self.assertEqual(ec2_hook.aws_conn_id, "aws_conn_test")
+        self.assertEqual(ec2_hook.region_name, "region-test")
+
+    @mock_ec2
+    def test_get_conn_returns_boto3_resource(self):
+        ec2_hook = EC2Hook()
+        instances = list(ec2_hook.conn.instances.all())
+        self.assertIsNotNone(instances)
+
+    @mock_ec2
+    def test_get_instance(self):
+        ec2_hook = EC2Hook()
+        created_instances = ec2_hook.conn.create_instances(
+            MaxCount=1,
+            MinCount=1,
+        )
+        created_instance_id = created_instances[0].instance_id
+        # test get_instance method
+        existing_instance = ec2_hook.get_instance(
+            instance_id=created_instance_id
+        )
+        self.assertEqual(created_instance_id, existing_instance.instance_id)
+
+    @mock_ec2
+    def test_get_instance_state(self):
+        ec2_hook = EC2Hook()
+        created_instances = ec2_hook.conn.create_instances(
+            MaxCount=1,
+            MinCount=1,
+        )
+        created_instance_id = created_instances[0].instance_id
+        all_instances = list(ec2_hook.conn.instances.all())
+        created_instance_state = all_instances[0].state["Name"]
+        # test get_instance_state method
+        existing_instance_state = ec2_hook.get_instance_state(
+            instance_id=created_instance_id
+        )
+        self.assertEqual(created_instance_state, existing_instance_state)
+
+
+if __name__ == '__main__':
+    unittest.main()

Review comment:
       ```suggestion
   ```

##########
File path: tests/providers/amazon/aws/operators/test_ec2_start_instance.py
##########
@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import unittest
+
+from moto import mock_ec2
+
+from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
+from airflow.providers.amazon.aws.operators.ec2_start_instance import EC2StartInstanceOperator
+
+
+class TestEC2Operator(unittest.TestCase):
+
+    def test_init(self):
+        ec2_operator = EC2StartInstanceOperator(
+            task_id="task_test",
+            instance_id="i-123abc",
+            aws_conn_id="aws_conn_test",
+            region_name="region-test",
+            check_interval=3,
+        )
+        self.assertEqual(ec2_operator.task_id, "task_test")
+        self.assertEqual(ec2_operator.instance_id, "i-123abc")
+        self.assertEqual(ec2_operator.aws_conn_id, "aws_conn_test")
+        self.assertEqual(ec2_operator.region_name, "region-test")
+        self.assertEqual(ec2_operator.check_interval, 3)
+
+    @mock_ec2
+    def test_start_instance(self):
+        # create instance
+        ec2_hook = EC2Hook()
+        instances = ec2_hook.conn.create_instances(
+            MaxCount=1,
+            MinCount=1,
+        )
+        instance_id = instances[0].instance_id
+
+        # start instance
+        start_test = EC2StartInstanceOperator(
+            task_id="start_test",
+            instance_id=instance_id,
+        )
+        start_test.execute(None)
+        # assert instance state is running
+        self.assertEqual(
+            ec2_hook.get_instance_state(instance_id=instance_id),
+            "running"
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()

Review comment:
       ```suggestion
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to