Lee-W commented on code in PR #32029:
URL: https://github.com/apache/airflow/pull/32029#discussion_r1239899387


##########
airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -317,3 +317,72 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
                     await asyncio.sleep(int(self.poll_interval))
 
             yield TriggerEvent({"status": "success", "job_id": self.job_id})
+
+
+class EmrStepSensorTrigger(BaseTrigger):
+    """
+    Poll for the status of EMR container until reaches terminal state.
+
+    :param job_flow_id: job_flow_id which contains the step check the state of
+    :param step_id:  step to check the state of
+    :param aws_conn_id: Reference to AWS connection id
+    :param poke_interval: polling period in seconds to check for the status
+    """
+
+    def __init__(
+        self,
+        job_flow_id: str,
+        step_id: str,
+        aws_conn_id: str = "aws_default",
+        poke_interval: int = 30,
+        **kwargs: Any,
+    ):
+        self.job_flow_id = job_flow_id
+        self.step_id = step_id
+        self.aws_conn_id = aws_conn_id
+        self.poke_interval = poke_interval
+        super().__init__(**kwargs)
+
+    @cached_property
+    def hook(self) -> EmrHook:
+        return EmrHook(self.aws_conn_id)
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger",
+            {
+                "job_flow_id": self.job_flow_id,
+                "step_id": self.step_id,
+                "aws_conn_id": self.aws_conn_id,
+                "poke_interval": self.poke_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        async with self.hook.async_conn as client:
+            waiter = self.hook.get_waiter("job_step_wait_for_terminal", deferrable=True, client=client)
+            attempt = 0
+            while True:
+                attempt = attempt + 1
+                try:
+                    await waiter.wait(
+                        ClusterId=self.job_flow_id,
+                        StepId=self.step_id,
+                        WaiterConfig={
+                            "Delay": self.poke_interval,
+                            "MaxAttempts": 1,

Review Comment:
   Thanks!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to