o-nikolas commented on code in PR #38287:
URL: https://github.com/apache/airflow/pull/38287#discussion_r1532511061


##########
airflow/providers/amazon/aws/operators/neptune.py:
##########
@@ -174,17 +253,91 @@ def __init__(
         self.delay = waiter_delay
         self.max_attempts = waiter_max_attempts
 
-    def execute(self, context: Context) -> dict[str, str]:
+    def execute(self, context: Context, event: dict[str, Any] | None = None, 
**kwargs) -> dict[str, str]:
         self.log.info("Stopping Neptune cluster: %s", self.cluster_id)
 
-        # Check to make sure the cluster is not already stopped.
+        if event:
+            # returning from a previous defer, need to restore properties
+            self.cluster_id = kwargs.get("cluster_id", self.cluster_id)
+            self.deferrable = kwargs.get("defer", self.deferrable)
+            self.delay = kwargs.get("waiter_delay", self.delay)
+            self.max_attempts = kwargs.get("waiter_max_attempts", 
self.max_attempts)
+            self.wait_for_completion = kwargs.get("wait_for_completion", 
self.wait_for_completion)
+            self.aws_conn_id = kwargs.get("aws_conn_id", self.aws_conn_id)
+            self.log.info("Restored properties from deferral")
+
+        # Check to make sure the cluster is not already stopped or that its 
not in a bad state
         status = self.hook.get_cluster_status(self.cluster_id)
+        self.log.info("Current status: %s", status)
+
         if status.lower() in NeptuneHook.STOPPED_STATES:
             self.log.info("Neptune cluster %s is already stopped.", 
self.cluster_id)
             return {"db_cluster_id": self.cluster_id}
-
-        resp = 
self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.cluster_id)
-        status = resp.get("DBClusters", {}).get("Status", "Unknown")
+        elif status.lower() in NeptuneHook.ERROR_STATES:
+            # some states will not allow you to stop the cluster
+            self.log.error(
+                "Neptune cluster %s is in error state %s and cannot be 
stopped", self.cluster_id, status
+            )
+            raise AirflowException(f"Neptune cluster {self.cluster_id} is in 
error state {status}")
+
+        """
+        A cluster and its instances must be in a valid state to send the stop 
request.
+        This loop covers the case where the cluster is not available and also 
the case where
+        the cluster is available, but one or more of the instances are in an 
invalid state.
+        If either are in an invalid state, wait for the availability and retry.
+        Let the waiters handle retries and detecting the error states.
+        """
+
+        try:
+            self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.cluster_id)
+
+        # cluster must be in available state to stop it
+        except ClientError as ex:
+            code = ex.response["Error"]["Code"]
+            self.log.warning("Received client error when attempting to stop 
the cluster: %s", code)
+
+            if code in ["InvalidDBInstanceStateFault", 
"InvalidClusterStateFault"]:
+                if self.deferrable:
+                    # save the arguments to restore after defer
+                    defer_args = {
+                        "cluster_id": self.cluster_id,
+                        "defer": self.deferrable,
+                        "wait_for_completion": self.wait_for_completion,
+                        "waiter_delay": self.delay,
+                        "waiter_max_attempts": self.max_attempts,
+                        "aws_conn_id": self.aws_conn_id,
+                    }
+                    if code == "InvalidDBInstanceStateFault":
+                        # wait for all instances to become available
+                        self.log.info("Deferring for instances to become 
available: %s", self.cluster_id)
+                        self.defer(
+                            trigger=NeptuneClusterInstancesAvailableTrigger(
+                                aws_conn_id=self.aws_conn_id,
+                                db_cluster_id=self.cluster_id,
+                            ),
+                            method_name="execute",
+                            kwargs=defer_args,
+                        )
+                    elif code == "InvalidClusterStateFault":
+                        self.log.info("Deferring for cluster to become 
available: %s", self.cluster_id)
+                        self.defer(
+                            trigger=NeptuneClusterAvailableTrigger(
+                                aws_conn_id=self.aws_conn_id,
+                                db_cluster_id=self.cluster_id,
+                            ),
+                            method_name="execute",
+                            kwargs=defer_args,
+                        )
+
+                else:
+                    self.log.info("Need to wait for cluster to become 
available: %s", self.cluster_id)
+                    self.hook.wait_for_cluster_availability(self.cluster_id)
+                    # make sure individual instances are available too.
+                    self.log.info("Need to wait for instances to become 
available: %s", self.cluster_id)
+                    
self.hook.wait_for_cluster_instance_availability(cluster_id=self.cluster_id)
+            else:
+                # re raise for any other type of client error
+                raise ex

Review Comment:
   ```suggestion
                   raise
   ```



##########
airflow/providers/amazon/aws/operators/neptune.py:
##########
@@ -174,17 +253,91 @@ def __init__(
         self.delay = waiter_delay
         self.max_attempts = waiter_max_attempts
 
-    def execute(self, context: Context) -> dict[str, str]:
+    def execute(self, context: Context, event: dict[str, Any] | None = None, 
**kwargs) -> dict[str, str]:
         self.log.info("Stopping Neptune cluster: %s", self.cluster_id)
 
-        # Check to make sure the cluster is not already stopped.
+        if event:
+            # returning from a previous defer, need to restore properties
+            self.cluster_id = kwargs.get("cluster_id", self.cluster_id)
+            self.deferrable = kwargs.get("defer", self.deferrable)
+            self.delay = kwargs.get("waiter_delay", self.delay)
+            self.max_attempts = kwargs.get("waiter_max_attempts", 
self.max_attempts)
+            self.wait_for_completion = kwargs.get("wait_for_completion", 
self.wait_for_completion)
+            self.aws_conn_id = kwargs.get("aws_conn_id", self.aws_conn_id)
+            self.log.info("Restored properties from deferral")
+
+        # Check to make sure the cluster is not already stopped or that its 
not in a bad state
         status = self.hook.get_cluster_status(self.cluster_id)
+        self.log.info("Current status: %s", status)
+
         if status.lower() in NeptuneHook.STOPPED_STATES:
             self.log.info("Neptune cluster %s is already stopped.", 
self.cluster_id)
             return {"db_cluster_id": self.cluster_id}
-
-        resp = 
self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.cluster_id)
-        status = resp.get("DBClusters", {}).get("Status", "Unknown")
+        elif status.lower() in NeptuneHook.ERROR_STATES:
+            # some states will not allow you to stop the cluster
+            self.log.error(
+                "Neptune cluster %s is in error state %s and cannot be 
stopped", self.cluster_id, status
+            )
+            raise AirflowException(f"Neptune cluster {self.cluster_id} is in 
error state {status}")
+
+        """
+        A cluster and its instances must be in a valid state to send the stop 
request.
+        This loop covers the case where the cluster is not available and also 
the case where
+        the cluster is available, but one or more of the instances are in an 
invalid state.
+        If either are in an invalid state, wait for the availability and retry.
+        Let the waiters handle retries and detecting the error states.
+        """
+
+        try:
+            self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.cluster_id)
+
+        # cluster must be in available state to stop it
+        except ClientError as ex:
+            code = ex.response["Error"]["Code"]
+            self.log.warning("Received client error when attempting to stop 
the cluster: %s", code)
+
+            if code in ["InvalidDBInstanceStateFault", 
"InvalidClusterStateFault"]:

Review Comment:
   This entire if block looks duplicated from the start-cluster operator (or at least very 
close to it). What do you think about abstracting it out into a shared helper method?



##########
tests/providers/amazon/aws/triggers/test_neptune.py:
##########
@@ -80,3 +81,32 @@ async def test_run_success(self, mock_async_conn, 
mock_get_waiter):
 
         assert resp == TriggerEvent({"status": "success", "db_cluster_id": 
CLUSTER_ID})
         assert mock_get_waiter().wait.call_count == 1
+
+
+class TestNeptuneClusterInstancesAvailableTrigger:
+    def test_serialization(self):
+        """
+        Asserts that the TaskStateTrigger correctly serializes its arguments
+        and classpath.
+        """
+        trigger = 
NeptuneClusterInstancesAvailableTrigger(db_cluster_id=CLUSTER_ID)
+        classpath, kwargs = trigger.serialize()
+        assert (
+            classpath
+            == 
"airflow.providers.amazon.aws.triggers.neptune.NeptuneClusterInstancesAvailableTrigger"
+        )
+        assert "db_cluster_id" in kwargs
+        assert kwargs["db_cluster_id"] == CLUSTER_ID
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.get_waiter")
+    
@mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.async_conn")
+    async def test_run_success(self, mock_async_conn, mock_get_waiter):

Review Comment:
   Could you add a test for the failure case as well?



##########
airflow/providers/amazon/aws/operators/neptune.py:
##########
@@ -81,17 +85,92 @@ def __init__(
         self.delay = waiter_delay
         self.max_attempts = waiter_max_attempts
 
-    def execute(self, context: Context) -> dict[str, str]:
+    def execute(self, context: Context, event: dict[str, Any] | None = None, 
**kwargs) -> dict[str, str]:
         self.log.info("Starting Neptune cluster: %s", self.cluster_id)
 
+        if event:
+            # returning from a previous defer, need to restore properties
+            self.cluster_id = kwargs.get("cluster_id", self.cluster_id)
+            self.deferrable = kwargs.get("defer", self.deferrable)
+            self.delay = kwargs.get("waiter_delay", self.delay)
+            self.max_attempts = kwargs.get("waiter_max_attempts", 
self.max_attempts)
+            self.wait_for_completion = kwargs.get("wait_for_completion", 
self.wait_for_completion)
+            self.aws_conn_id = kwargs.get("aws_conn_id", self.aws_conn_id)
+            self.log.info("Restored properties from deferral")
+
         # Check to make sure the cluster is not already available.
         status = self.hook.get_cluster_status(self.cluster_id)
         if status.lower() in NeptuneHook.AVAILABLE_STATES:
             self.log.info("Neptune cluster %s is already available.", 
self.cluster_id)
             return {"db_cluster_id": self.cluster_id}
-
-        resp = 
self.hook.conn.start_db_cluster(DBClusterIdentifier=self.cluster_id)
-        status = resp.get("DBClusters", {}).get("Status", "Unknown")
+        elif status.lower() in NeptuneHook.ERROR_STATES:
+            # some states will not allow you to start the cluster
+            self.log.error(
+                "Neptune cluster %s is in error state %s and cannot be 
started", self.cluster_id, status
+            )
+            raise AirflowException(f"Neptune cluster {self.cluster_id} is in 
error state {status}")
+
+        """
+        A cluster and its instances must be in a valid state to send the start 
request.
+        This loop covers the case where the cluster is not available and also 
the case where
+        the cluster is available, but one or more of the instances are in an 
invalid state.
+        If either are in an invalid state, wait for the availability and retry.
+        Let the waiters handle retries and detecting the error states.
+        """
+        try:
+            
self.hook.conn.start_db_cluster(DBClusterIdentifier=self.cluster_id)
+        except ClientError as ex:
+            code = ex.response["Error"]["Code"]
+            self.log.warning("Received client error when attempting to start 
the cluster: %s", code)
+
+            if code in ["InvalidDBInstanceStateFault", 
"InvalidClusterStateFault"]:
+                if self.deferrable:
+                    # save the arguments to restore after defer
+                    defer_args = {
+                        "cluster_id": self.cluster_id,
+                        "defer": self.deferrable,

Review Comment:
   Why do we need to save this? If we're coming back from a defer, don't we 
already know that deferrable was true?



##########
airflow/providers/amazon/aws/triggers/neptune.py:
##########
@@ -113,3 +113,48 @@ def hook(self) -> AwsGenericHook:
             verify=self.verify,
             config=self.botocore_config,
         )
+
+
+class NeptuneClusterInstancesAvailableTrigger(AwsBaseWaiterTrigger):
+    """
+    Triggers when a Neptune Cluster Instances available.

Review Comment:
   ```suggestion
       Triggers when a Neptune Cluster Instance is available.
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to