pankajastro commented on code in PR #30279:
URL: https://github.com/apache/airflow/pull/30279#discussion_r1217854689


##########
airflow/providers/amazon/aws/sensors/batch.py:
##########
@@ -75,6 +84,32 @@ def poke(self, context: Context) -> bool:
 
        raise AirflowException(f"Batch sensor failed. Unknown AWS Batch job status: {state}")
 
+    def execute(self, context: Context) -> None:
+        if not self.deferrable:
+            super().execute(context=context)
+        else:
+            self.defer(
+                timeout=timedelta(seconds=self.timeout),
+                trigger=BatchSensorTrigger(
+                    job_id=self.job_id,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    poke_interval=self.poke_interval,
+                    max_retries=self.max_retries,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
+        """
+        Callback for when the trigger fires - returns immediately.
+        Relies on trigger to throw an exception, otherwise it assumes execution was
+        successful.
+        """
+        if "status" in event and event["status"] == "error":

Review Comment:
   ```suggestion
           if "status" in event and event["status"] == "failure":
   ```
   no?



##########
airflow/providers/amazon/aws/hooks/batch_client.py:
##########
@@ -577,3 +579,77 @@ def exp(tries):
         delay = 1 + pow(tries * 0.6, 2)
         delay = min(max_interval, delay)
         return uniform(delay / 3, delay)
+
+
+class BatchClientAsyncHook(BatchClientHook, AwsBaseAsyncHook):

Review Comment:
   Where are we using this?



##########
airflow/providers/amazon/aws/triggers/batch.py:
##########
@@ -105,3 +105,92 @@ async def run(self):
             yield TriggerEvent({"status": "failure", "message": "Job Failed - max attempts reached."})
         else:
             yield TriggerEvent({"status": "success", "job_id": self.job_id})
+
+
+class BatchSensorTrigger(BaseTrigger):
+    """
+    Checks for the status of a submitted job_id to AWS Batch until it reaches a failure or a success state.
+    BatchSensorTrigger is fired as a deferred class with params to poll the job state in the Triggerer.
+
+    :param job_id: the job ID, to poll for job completion or not
+    :param aws_conn_id: connection id of AWS credentials / region name. If None,
+        credential boto3 strategy will be used
+    :param region_name: AWS region name to use
+        Override the region_name in connection (if provided)
+    :param max_retries: Number of times to poll for job state before
+        returning the current state, defaults to None
+    :param poke_interval: polling period in seconds to check for the status of the job
+    """
+
+    def __init__(
+        self,
+        job_id: str,
+        region_name: str | None,
+        aws_conn_id: str | None = "aws_default",
+        poke_interval: float = 5,
+        max_retries: int = 5,
+    ):
+        super().__init__()
+        self.job_id = job_id
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes BatchSensorTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.batch.BatchSensorTrigger",
+            {
+                "job_id": self.job_id,
+                "aws_conn_id": self.aws_conn_id,
+                "region_name": self.region_name,
+                "poke_interval": self.poke_interval,
+                "max_retries": self.max_retries,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> BatchClientHook:
+        return BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+
+    async def run(self):
+        """
+        Make async connection using aiobotocore library to AWS Batch,
+        periodically poll for the Batch job status
+
+        The statuses that indicate job completion are: 'SUCCEEDED'|'FAILED'.
+        """
+        async with self.hook.async_conn as client:
+            waiter = self.hook.get_waiter("batch_job_complete", deferrable=True, client=client)
+            attempt = 0
+            while attempt < self.max_retries:
+                attempt = attempt + 1
+                try:
+                    await waiter.wait(
+                        jobs=[self.job_id],
+                        WaiterConfig={
+                            "Delay": self.poke_interval,

Review Comment:
   I'm not sure if it accepts a float or not. The delay should be an int.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to