Lee-W commented on code in PR #32029:
URL: https://github.com/apache/airflow/pull/32029#discussion_r1236264059


##########
airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -317,3 +317,75 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
                     await asyncio.sleep(int(self.poll_interval))
 
             yield TriggerEvent({"status": "success", "job_id": self.job_id})
+
+
+class EmrStepSensorTrigger(BaseTrigger):
+    """
+    Poll for the status of EMR container until reaches terminal state.
+
+    :param virtual_cluster_id: Reference Emr cluster id
+    :param job_id:  job_id to check the state
+    :param aws_conn_id: Reference to AWS connection id
+    :param poke_interval: polling period in seconds to check for the status

Review Comment:
   These parameters are inconsistent with `__init__`. Or are they meant for 
somewhere else?



##########
airflow/providers/amazon/aws/sensors/emr.py:
##########
@@ -587,3 +626,26 @@ def failure_message_from_response(response: dict[str, 
Any]) -> str | None:
                 f"with message {fail_details.get('Message')} and log file 
{fail_details.get('LogFile')}"
             )
         return None
+
+    def execute(self, context: Context) -> None:
+        if not self.deferrable:
+            super().execute(context=context)
+        else:
+            timeout = self.max_attempts * self.poke_interval + 60
+            self.defer(
+                timeout=timedelta(seconds=timeout),
+                trigger=EmrStepSensorTrigger(
+                    job_flow_id=self.job_flow_id,
+                    step_id=self.step_id,
+                    target_states=self.target_states,
+                    aws_conn_id=self.aws_conn_id,
+                    poke_interval=int(self.poke_interval),
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        else:
+            self.log.info("Job completed.")

Review Comment:
   It seems to me that this else block might not be needed.
   ```suggestion
           if event["status"] != "success":
               raise AirflowException(f"Error while running job: {event}")
   
           self.log.info("Job completed.")
   ```



##########
airflow/providers/amazon/aws/operators/emr.py:
##########
@@ -555,6 +559,22 @@ def execute(self, context: Context) -> str | None:
             self.client_request_token,
             self.tags,
         )
+        if self.deferrable:
+            timeout = (
+                timedelta(seconds=self.max_polling_attempts * 
self.poll_interval + 60)

Review Comment:
   +1 to the constant idea. This `60` seems to appear in multiple places.



##########
tests/providers/amazon/aws/sensors/test_emr_job_flow.py:
##########
@@ -276,3 +277,20 @@ def test_different_target_states(self):
             # make sure it was called with the job_flow_id
             calls = [mock.call(ClusterId="j-8989898989")]
             self.mock_emr_client.describe_cluster.assert_has_calls(calls)
+
+    
@mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrJobFlowSensor.poke")
+    def test_sensor_defer(self, mock_poke):
+        sensor = EmrJobFlowSensor(
+            task_id="test_task",
+            poke_interval=0,
+            job_flow_id="j-8989898989",
+            aws_conn_id="aws_default",
+            target_states=["RUNNING", "WAITING"],
+            deferrable=True,
+        )
+        mock_poke.return_value = False
+        with pytest.raises(TaskDeferred) as exc:
+            sensor.execute(context=None)
+        assert isinstance(
+            exc.value.trigger, EmrTerminateJobFlowTrigger
+        ), "Trigger is not a EmrTerminateJobFlowTrigger "

Review Comment:
   Perhaps we can enhance the message by including the trigger in the print.
   ```suggestion
           ), f"{exc.value.trigger} is not a EmrTerminateJobFlowTrigger "
   ```



##########
tests/providers/amazon/aws/operators/test_emr_containers.py:
##########
@@ -144,6 +145,14 @@ def test_execute_with_polling_timeout(self, 
mock_check_query_status):
             assert "Final state of EMR Containers job is SUBMITTED" in 
str(ctx.value)
             assert "Max tries of poll status exceeded" in str(ctx.value)
 
+    @mock.patch.object(EmrContainerHook, "submit_job")
+    def test_operator_defer(self, mock_submit_job):
+        self.emr_container.deferrable = True
+        self.emr_container.wait_for_completion = False
+        with pytest.raises(TaskDeferred) as exc:
+            self.emr_container.execute(context=None)
+        assert isinstance(exc.value.trigger, EmrContainerTrigger), "Trigger is 
not a EmrContainerTrigger"

Review Comment:
   Perhaps we can enhance the message by including the trigger in the print.
   
   ```suggestion
           assert isinstance(exc.value.trigger, EmrContainerTrigger), 
f"{exc.value.trigger} is not a EmrContainerTrigger"
   ```



##########
airflow/providers/amazon/aws/operators/emr.py:
##########
@@ -576,6 +596,13 @@ def execute(self, context: Context) -> str | None:
 
         return self.job_id
 
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        else:
+            self.log.info(event["message"])
+            return event["job_id"]

Review Comment:
   It seems to me that this else block might not be needed.
   
   ```suggestion
           if event["status"] != "success":
               raise AirflowException(f"Error while running job: {event}")
   
           self.log.info(event["message"])
           return event["job_id"]
   ```



##########
airflow/providers/amazon/aws/operators/emr.py:
##########
@@ -507,6 +509,7 @@ def __init__(
         max_tries: int | None = None,
         tags: dict | None = None,
         max_polling_attempts: int | None = None,
+        deferrable: bool = False,

Review Comment:
   sounds good 👍 



##########
tests/providers/amazon/aws/sensors/test_emr_containers.py:
##########
@@ -81,6 +81,4 @@ def test_sensor_defer(self, mock_poke):
         mock_poke.return_value = False
         with pytest.raises(TaskDeferred) as exc:
             self.sensor.execute(context=None)
-        assert isinstance(
-            exc.value.trigger, EmrContainerSensorTrigger
-        ), "Trigger is not a EmrContainerSensorTrigger"
+        assert isinstance(exc.value.trigger, EmrContainerTrigger), "Trigger is 
not a EmrContainerTrigger"

Review Comment:
   Perhaps we can enhance the message by including the trigger in the print.
   
   ```suggestion
           assert isinstance(exc.value.trigger, EmrContainerTrigger), 
f"{exc.value.trigger} is not a EmrContainerTrigger"
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to