ferruzzi commented on code in PR #38287:
URL: https://github.com/apache/airflow/pull/38287#discussion_r1608627264


##########
airflow/providers/amazon/aws/hooks/neptune.py:
##########
@@ -83,3 +89,32 @@ def get_cluster_status(self, cluster_id: str) -> str:
         :return: The status of the cluster.
         """
         return self.get_conn().describe_db_clusters(DBClusterIdentifier=cluster_id)["DBClusters"][0]["Status"]
+
+    def get_db_instance_status(self, instance_id: str) -> str:
+        """
+        Get the status of a Neptune instance.
+
+        :param instance_id: The ID of the instance to get the status of.
+        :return: The status of the instance.
+        """
+        return self.get_conn().describe_db_instances(DBInstanceIdentifier=instance_id)["DBInstances"][0][

Review Comment:
   Small nit; I won't block on it since it was already done this way above on 
L91, but we should be using `self.conn` (a cached property in `AwsGenericHook`) 
instead of `self.get_conn()`.
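
To illustrate the cached-property point, here is a minimal sketch using only the standard library. `ToyHook`, its counter, and the way `get_conn()` delegates to `conn` are assumptions made for the example; this is not the real `AwsGenericHook` code.

```python
from functools import cached_property


class ToyHook:
    """Hypothetical stand-in for a hook; not the real AwsGenericHook."""

    def __init__(self):
        self.clients_created = 0

    @cached_property
    def conn(self):
        # Runs on first access only; the result is then stored on the instance.
        self.clients_created += 1
        return object()  # placeholder for an expensive client object

    def get_conn(self):
        # Assumed here to be a thin wrapper around the cached property.
        return self.conn


hook = ToyHook()
first = hook.conn
second = hook.conn
via_method = hook.get_conn()
```

In this toy, all three names refer to the same object and the factory body runs once, so reading `conn` directly expresses the intent without the extra call.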



##########
tests/providers/amazon/aws/operators/test_neptune.py:
##########
@@ -100,11 +102,113 @@ def test_start_cluster_cluster_available(self, mock_waiter, mock_get_cluster_sta
         mock_waiter.assert_not_called()
         assert resp == {"db_cluster_id": CLUSTER_ID}
 
+    @mock.patch.object(NeptuneHook, "conn")
+    def test_start_cluster_deferrable(self, mock_conn):
+        operator = NeptuneStartDbClusterOperator(
+            task_id="task_test",
+            db_cluster_id=CLUSTER_ID,
+            deferrable=True,
+            wait_for_completion=False,
+            aws_conn_id="aws_default",
+        )
+
+        with pytest.raises(TaskDeferred):

Review Comment:
   Here and elsewhere:  Do we want to use `match` here to also assert that we 
get the expected message and not just the type?
   
   ```
   with pytest.raises(TaskDeferred, match="expected message here"):
       ...
   ```
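
For anyone unsure what `match` checks: pytest documents it as a regular expression tested with `re.search` against the string form of the exception. The sketch below imitates that check with the standard library only, so it is an approximation and not pytest's implementation. `DeferredLike`, `start_cluster`, `raises_matching`, and the message text are all invented for the example.

```python
import re


class DeferredLike(Exception):
    """Hypothetical stand-in for the exception raised by the code under test."""


def start_cluster():
    # Hypothetical function under test; the message text is invented.
    raise DeferredLike("cluster my-cluster (starting) is not ready")


def raises_matching(exc_type, pattern, func):
    """Return True if func raises exc_type and pattern is found in str(exc)."""
    try:
        func()
    except exc_type as exc:
        return re.search(pattern, str(exc)) is not None
    return False


# A substring is enough, because the check is a search, not a full match.
loose = raises_matching(DeferredLike, "is not ready", start_cluster)

# Parentheses are regex groups, so the unescaped literal does not match.
unescaped = raises_matching(DeferredLike, "my-cluster (starting)", start_cluster)

# re.escape turns the literal message fragment into a safe pattern.
escaped = raises_matching(DeferredLike, re.escape("my-cluster (starting)"), start_cluster)
```

The practical upshot for the tests is that a literal message containing regex metacharacters should be wrapped in `re.escape` before being passed as `match`.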



##########
airflow/providers/amazon/aws/operators/neptune.py:
##########
@@ -32,6 +36,49 @@
     from airflow.utils.context import Context
 
 
+def handle_waitable_exception(
+    operator: NeptuneStartDbClusterOperator | NeptuneStopDbClusterOperator, err: str
+):
+    """Handle client exceptions for invalid cluster or invalid instance status that are temporary.
+
+    After status change, its possible to retry. Waiter will handle terminal status.

Review Comment:
   ```suggestion
       After status change, it's possible to retry. Waiter will handle terminal status.
   ```



##########
airflow/providers/amazon/aws/operators/neptune.py:
##########
@@ -32,6 +36,49 @@
     from airflow.utils.context import Context
 
 
+def handle_waitable_exception(
+    operator: NeptuneStartDbClusterOperator | NeptuneStopDbClusterOperator, err: str
+):
+    """Handle client exceptions for invalid cluster or invalid instance status that are temporary.

Review Comment:
   ```suggestion
       """
       Handle client exceptions for invalid cluster or invalid instance status that are temporary.
   ```
   
   Non-blocking: I forget which rule it is offhand, but one of the `ruff` 
rules we are trying to turn on requires multi-line docstrings like this to 
start on a new line. Someone will have to come back and change this later, so 
we may as well do it now if you don't mind.



##########
tests/providers/amazon/aws/operators/test_neptune.py:
##########
@@ -150,3 +257,114 @@ def test_stop_cluster_cluster_stopped(self, mock_waiter, mock_get_cluster_status
         mock_conn.stop_db_cluster.assert_not_called()
         mock_waiter.assert_not_called()
         assert resp == {"db_cluster_id": CLUSTER_ID}
+
+    @mock.patch.object(NeptuneHook, "conn")
+    @mock.patch.object(NeptuneHook, "get_cluster_status")
+    @mock.patch.object(NeptuneHook, "get_waiter")
+    def test_stop_cluster_cluster_error(self, mock_waiter, mock_get_cluster_status, mock_conn):
+        mock_get_cluster_status.return_value = "migration-failed"
+        operator = NeptuneStopDbClusterOperator(
+            task_id="task_test",
+            db_cluster_id=CLUSTER_ID,
+            deferrable=False,
+            wait_for_completion=True,
+            aws_conn_id="aws_default",
+        )
+
+        with pytest.raises(AirflowException):
+            operator.execute(None)
+
+    @mock.patch.object(NeptuneHook, "conn")
+    @mock.patch.object(NeptuneHook, "get_cluster_status")
+    @mock.patch.object(NeptuneHook, "get_waiter")
+    def test_stop_cluster_not_in_available(self, mock_waiter, mock_get_cluster_status, mock_conn):
+        mock_get_cluster_status.return_value = "backing-up"
+        operator = NeptuneStopDbClusterOperator(
+            task_id="task_test",
+            db_cluster_id=CLUSTER_ID,
+            deferrable=False,
+            wait_for_completion=True,
+            aws_conn_id="aws_default",
+        )
+
+        operator.execute(None)
+        mock_waiter.assert_called_with("cluster_stopped")
+
+    @mock.patch("airflow.providers.amazon.aws.operators.neptune.NeptuneStopDbClusterOperator.defer")

Review Comment:
   Is there a reason to use `mock.patch` here instead of `mock.patch.object`, 
as you did on the next line?
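
For comparison, a small sketch of the two forms side by side. It patches `string.Template.substitute` from the standard library purely as a neutral target; `render` is a hypothetical helper and none of this is taken from the PR's tests.

```python
import string
from unittest import mock


def render(name):
    # Hypothetical code under test that calls the method being patched.
    return string.Template("hello $who").substitute(who=name)


# Form 1: the target is a dotted string, resolved by import at patch time.
with mock.patch("string.Template.substitute", return_value="patched") as by_path:
    via_path = render("a")
    by_path.assert_called_once_with(who="a")

# Form 2: the target is the class object itself plus an attribute name.
with mock.patch.object(string.Template, "substitute", return_value="patched") as by_object:
    via_object = render("b")
    by_object.assert_called_once_with(who="b")

# Both patches are undone when their context managers exit.
restored = render("c")
```

Both forms replace the same attribute here. The `patch.object` form has the side benefit that a renamed class or module fails at import time instead of inside a string, which is one reason to prefer it when the class is already imported in the test module.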



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
