pankajastro commented on code in PR #30279:
URL: https://github.com/apache/airflow/pull/30279#discussion_r1226888817
##########
airflow/providers/amazon/aws/sensors/batch.py:
##########
@@ -75,6 +87,32 @@ def poke(self, context: Context) -> bool:
raise AirflowException(f"Batch sensor failed. Unknown AWS Batch job
status: {state}")
+ def execute(self, context: Context) -> None:
+ if not self.deferrable:
+ super().execute(context=context)
+ else:
+ self.defer(
+ timeout=timedelta(seconds=self.timeout),
Review Comment:
I think we should calculate this timeout based on poke_interval and
max_retries, if those are given — what do you think?
##########
airflow/providers/amazon/aws/triggers/batch.py:
##########
@@ -105,3 +105,92 @@ async def run(self):
yield TriggerEvent({"status": "failure", "message": "Job Failed -
max attempts reached."})
else:
yield TriggerEvent({"status": "success", "job_id": self.job_id})
+
+
+class BatchSensorTrigger(BaseTrigger):
+ """
+ Checks for the status of a submitted job_id to AWS Batch until it reaches
a failure or a success state.
+ BatchSensorTrigger is fired as deferred class with params to poll the job
state in Triggerer.
+
+ :param job_id: the job ID to poll for completion
+ :param region_name: AWS region name to use
+ Override the region_name in connection (if provided)
+ :param aws_conn_id: connection id of AWS credentials / region name. If
None,
+ credential boto3 strategy will be used
+ :param poke_interval: polling period in seconds to check for the status of
the job
+ :param max_retries: number of times to poll for job state before
+ returning the current state; defaults to 5
+ """
+
+ def __init__(
+ self,
+ job_id: str,
+ region_name: str | None,
+ aws_conn_id: str | None = "aws_default",
+ poke_interval: float = 5,
+ max_retries: int = 5,
+ ):
+ super().__init__()
+ self.job_id = job_id
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+ self.poke_interval = poke_interval
+ self.max_retries = max_retries
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes BatchSensorTrigger arguments and classpath."""
+ return (
+ "airflow.providers.amazon.aws.triggers.batch.BatchSensorTrigger",
+ {
+ "job_id": self.job_id,
+ "aws_conn_id": self.aws_conn_id,
+ "region_name": self.region_name,
+ "poke_interval": self.poke_interval,
+ "max_retries": self.max_retries,
+ },
+ )
+
+ @cached_property
+ def hook(self) -> BatchClientHook:
+ return BatchClientHook(aws_conn_id=self.aws_conn_id,
region_name=self.region_name)
+
+ async def run(self):
+ """
+ Make async connection using aiobotocore library to AWS Batch,
+ periodically poll for the Batch job status.
+
+ The statuses that indicate job completion are: 'SUCCEEDED' | 'FAILED'.
+ """
+ async with self.hook.async_conn as client:
+ waiter = self.hook.get_waiter("batch_job_complete",
deferrable=True, client=client)
+ attempt = 0
+ while attempt < self.max_retries:
Review Comment:
Just posting some related discussion here:
https://github.com/apache/airflow/pull/30945#discussion_r1218325194. What would
happen if the trigger stops and is picked up again on another machine? In that
case this max_retries counter would reset, so it would not make much sense. wdyt?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]