vincbeck commented on code in PR #28827:
URL: https://github.com/apache/airflow/pull/28827#discussion_r1068352775
##########
airflow/providers/amazon/aws/hooks/emr.py:
##########
@@ -169,6 +169,15 @@ def add_job_flow_steps(
)
return response["StepIds"]
+ def terminate_job_flow(self, job_flow_id: str) -> None:
Review Comment:
We generally try to avoid adding hook methods that merely wrap a single boto3 API call. You
can call the boto3 API directly from the operator instead.
##########
airflow/providers/amazon/aws/operators/emr.py:
##########
@@ -538,42 +544,76 @@ def __init__(
emr_conn_id: str | None = "emr_default",
job_flow_overrides: str | dict[str, Any] | None = None,
region_name: str | None = None,
+ wait_for_completion: bool = False,
+ waiter_countdown: int | None = None,
+ waiter_check_interval_seconds: int = 60,
**kwargs,
):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.emr_conn_id = emr_conn_id
self.job_flow_overrides = job_flow_overrides or {}
self.region_name = region_name
+ self.wait_for_completion = wait_for_completion
+ self.waiter_countdown = waiter_countdown
+ self.waiter_check_interval_seconds = waiter_check_interval_seconds
+
+ self._job_flow_id: str | None = None
- def execute(self, context: Context) -> str:
- emr = EmrHook(
+ @cached_property
+ def _emr_hook(self) -> EmrHook:
+ """Create and return an EmrHook."""
+ return EmrHook(
aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id,
region_name=self.region_name
)
+ def execute(self, context: Context) -> str | None:
self.log.info(
- "Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s",
self.aws_conn_id, self.emr_conn_id
+ "Creating job flow using aws_conn_id: %s, emr_conn_id: %s",
self.aws_conn_id, self.emr_conn_id
)
if isinstance(self.job_flow_overrides, str):
job_flow_overrides: dict[str, Any] =
ast.literal_eval(self.job_flow_overrides)
self.job_flow_overrides = job_flow_overrides
else:
job_flow_overrides = self.job_flow_overrides
- response = emr.create_job_flow(job_flow_overrides)
+ response = self._emr_hook.create_job_flow(job_flow_overrides)
if not response["ResponseMetadata"]["HTTPStatusCode"] == 200:
- raise AirflowException(f"JobFlow creation failed: {response}")
+ raise AirflowException(f"Job flow creation failed: {response}")
else:
- job_flow_id = response["JobFlowId"]
- self.log.info("JobFlow with id %s created", job_flow_id)
+ self._job_flow_id = response["JobFlowId"]
+ self.log.info("Job flow with id %s created", self._job_flow_id)
EmrClusterLink.persist(
context=context,
operator=self,
- region_name=emr.conn_region_name,
- aws_partition=emr.conn_partition,
- job_flow_id=job_flow_id,
+ region_name=self._emr_hook.conn_region_name,
+ aws_partition=self._emr_hook.conn_partition,
+ job_flow_id=self._job_flow_id,
)
- return job_flow_id
+
+ if self.wait_for_completion:
+ # Didn't use a boto-supplied waiter because those don't
support waiting for WAITING state.
+ #
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#waiters
+ waiter(
+
get_state_callable=self._emr_hook.get_conn().describe_cluster,
+ get_state_args={"ClusterId": self._job_flow_id},
+ parse_response=["Cluster", "Status", "State"],
+ # Cluster will be in WAITING after finishing if
KeepJobFlowAliveWhenNoSteps is True
+ desired_state={"WAITING", "TERMINATED"},
+ failure_states={"TERMINATED_WITH_ERRORS"},
+ object_type="job flow",
+ action="finished",
+ countdown=self.waiter_countdown,
+ check_interval_seconds=self.waiter_check_interval_seconds,
+ )
+
+ return self._job_flow_id
+
+ def on_kill(self) -> None:
+ """Terminate job flow."""
+ if self._job_flow_id:
+ self.log.info("Terminating job flow %s", self._job_flow_id)
+ self._emr_hook.terminate_job_flow(self._job_flow_id)
Review Comment:
```suggestion
self._emr_hook.get_conn().terminate_job_flows(JobFlowIds=[self._job_flow_id])
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]