vandonr-amz commented on code in PR #32274:
URL: https://github.com/apache/airflow/pull/32274#discussion_r1254913020
##########
airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -102,274 +102,184 @@ async def run(self):
yield TriggerEvent({"status": "success", "message": "Steps
completed", "step_ids": self.step_ids})
-class EmrCreateJobFlowTrigger(BaseTrigger):
+class EmrCreateJobFlowTrigger(AwsBaseWaiterTrigger):
"""
Asynchronously poll the boto3 API and wait for the JobFlow to finish
executing.
:param job_flow_id: The id of the job flow to wait for.
- :param poll_interval: The amount of time in seconds to wait between
attempts.
- :param max_attempts: The maximum number of attempts to be made.
+ :param waiter_delay: The amount of time in seconds to wait between
attempts.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""
def __init__(
self,
job_flow_id: str,
- poll_interval: int,
- max_attempts: int,
- aws_conn_id: str,
+ poll_interval: int | None = None, # deprecated
+ max_attempts: int | None = None, # deprecated
+ aws_conn_id: str | None = None,
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 60,
):
- self.job_flow_id = job_flow_id
- self.poll_interval = poll_interval
- self.max_attempts = max_attempts
- self.aws_conn_id = aws_conn_id
-
- def serialize(self) -> tuple[str, dict[str, Any]]:
- return (
- self.__class__.__module__ + "." + self.__class__.__qualname__,
- {
- "job_flow_id": self.job_flow_id,
- "poll_interval": str(self.poll_interval),
- "max_attempts": str(self.max_attempts),
- "aws_conn_id": self.aws_conn_id,
- },
+ if poll_interval is not None or max_attempts is not None:
+ warnings.warn(
+ "please use waiter_delay instead of poll_interval "
+ "and waiter_max_attempts instead of max_attempts",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ waiter_delay = poll_interval or waiter_delay
+ waiter_max_attempts = max_attempts or waiter_max_attempts
+ super().__init__(
+ serialized_fields={"job_flow_id": job_flow_id},
+ waiter_name="job_flow_waiting",
+ waiter_args={"ClusterId": job_flow_id},
+ failure_message="JobFlow creation failed",
+ status_message="JobFlow creation in progress",
+ status_queries=[
+ "Cluster.Status.State",
+ "Cluster.Status.StateChangeReason",
+ "Cluster.Status.ErrorDetails",
+ ],
+ return_key="job_flow_id",
+ return_value=job_flow_id,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
)
- async def run(self):
- self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
- async with self.hook.async_conn as client:
- attempt = 0
- waiter = self.hook.get_waiter("job_flow_waiting", deferrable=True,
client=client)
- while attempt < int(self.max_attempts):
- attempt = attempt + 1
- try:
- await waiter.wait(
- ClusterId=self.job_flow_id,
- WaiterConfig=prune_dict(
- {
- "Delay": self.poll_interval,
- "MaxAttempts": 1,
- }
- ),
- )
- break
- except WaiterError as error:
- if "terminal failure" in str(error):
- raise AirflowException(f"JobFlow creation failed:
{error}")
- self.log.info(
- "Status of jobflow is %s - %s",
- error.last_response["Cluster"]["Status"]["State"],
-
error.last_response["Cluster"]["Status"]["StateChangeReason"],
- )
- await asyncio.sleep(int(self.poll_interval))
- if attempt >= int(self.max_attempts):
- raise AirflowException(f"JobFlow creation failed - max attempts
reached: {self.max_attempts}")
- else:
- yield TriggerEvent(
- {
- "status": "success",
- "message": "JobFlow completed successfully",
- "job_flow_id": self.job_flow_id,
- }
- )
+ def hook(self) -> AwsGenericHook:
+ return EmrHook(aws_conn_id=self.aws_conn_id)
-class EmrTerminateJobFlowTrigger(BaseTrigger):
+class EmrTerminateJobFlowTrigger(AwsBaseWaiterTrigger):
"""
Asynchronously poll the boto3 API and wait for the JobFlow to finish
terminating.
:param job_flow_id: ID of the EMR Job Flow to terminate
- :param poll_interval: The amount of time in seconds to wait between
attempts.
- :param max_attempts: The maximum number of attempts to be made.
+ :param waiter_delay: The amount of time in seconds to wait between
attempts.
+ :param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""
def __init__(
self,
job_flow_id: str,
- poll_interval: int,
- max_attempts: int,
- aws_conn_id: str,
+ poll_interval: int | None = None, # deprecated
+ max_attempts: int | None = None, # deprecated
+ aws_conn_id: str | None = None,
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 60,
):
- self.job_flow_id = job_flow_id
- self.poll_interval = poll_interval
- self.max_attempts = max_attempts
- self.aws_conn_id = aws_conn_id
-
- def serialize(self) -> tuple[str, dict[str, Any]]:
- return (
- self.__class__.__module__ + "." + self.__class__.__qualname__,
- {
- "job_flow_id": self.job_flow_id,
- "poll_interval": str(self.poll_interval),
- "max_attempts": str(self.max_attempts),
- "aws_conn_id": self.aws_conn_id,
- },
+ if poll_interval is not None or max_attempts is not None:
+ warnings.warn(
+ "please use waiter_delay instead of poll_interval "
+ "and waiter_max_attempts instead of max_attempts",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ waiter_delay = poll_interval or waiter_delay
+ waiter_max_attempts = max_attempts or waiter_max_attempts
+ super().__init__(
+ serialized_fields={"job_flow_id": job_flow_id},
+ waiter_name="job_flow_terminated",
+ waiter_args={"ClusterId": job_flow_id},
+ failure_message="JobFlow termination failed",
+ status_message="JobFlow termination in progress",
+ status_queries=[
+ "Cluster.Status.State",
+ "Cluster.Status.StateChangeReason",
+ "Cluster.Status.ErrorDetails",
+ ],
+ return_value=None,
+ waiter_delay=waiter_delay,
+ waiter_max_attempts=waiter_max_attempts,
+ aws_conn_id=aws_conn_id,
)
- async def run(self):
- self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
- async with self.hook.async_conn as client:
- attempt = 0
- waiter = self.hook.get_waiter("job_flow_terminated",
deferrable=True, client=client)
- while attempt < int(self.max_attempts):
- attempt = attempt + 1
- try:
- await waiter.wait(
- ClusterId=self.job_flow_id,
- WaiterConfig=prune_dict(
- {
- "Delay": self.poll_interval,
- "MaxAttempts": 1,
- }
- ),
- )
- break
- except WaiterError as error:
- if "terminal failure" in str(error):
- raise AirflowException(f"JobFlow termination failed:
{error}")
- self.log.info(
- "Status of jobflow is %s - %s",
- error.last_response["Cluster"]["Status"]["State"],
-
error.last_response["Cluster"]["Status"]["StateChangeReason"],
- )
- await asyncio.sleep(int(self.poll_interval))
- if attempt >= int(self.max_attempts):
- raise AirflowException(f"JobFlow termination failed - max attempts
reached: {self.max_attempts}")
- else:
- yield TriggerEvent(
- {
- "status": "success",
- "message": "JobFlow terminated successfully",
- }
- )
+ def hook(self) -> AwsGenericHook:
+ return EmrHook(aws_conn_id=self.aws_conn_id)
-class EmrContainerTrigger(BaseTrigger):
+class EmrContainerTrigger(AwsBaseWaiterTrigger):
"""
Poll for the status of EMR container until reaches terminal state.
:param virtual_cluster_id: Reference Emr cluster id
:param job_id: job_id to check the state
:param aws_conn_id: Reference to AWS connection id
- :param poll_interval: polling period in seconds to check for the status
+ :param waiter_delay: polling period in seconds to check for the status
"""
def __init__(
self,
virtual_cluster_id: str,
job_id: str,
aws_conn_id: str = "aws_default",
- poll_interval: int = 30,
- **kwargs: Any,
+ poll_interval: int | None = None, # deprecated
+ waiter_delay: int = 30,
+ waiter_max_attempts: int = 600,
Review Comment:
OK, but in this particular case the existing behavior was to wait forever,
so adding a finite default for `waiter_max_attempts` (600) changes that behavior.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org