eladkal commented on a change in pull request #20160: URL: https://github.com/apache/airflow/pull/20160#discussion_r768394156
########## File path: airflow/providers/amazon/aws/sensors/emr.py ########## @@ -0,0 +1,347 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, Iterable, Optional + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.emr import EmrHook +from airflow.providers.amazon.aws.hooks.emr_containers import EMRContainerHook +from airflow.sensors.base import BaseSensorOperator + + +class EmrBaseSensor(BaseSensorOperator): + """ + Contains general sensor behavior for EMR. + + Subclasses should implement following methods: + - ``get_emr_response()`` + - ``state_from_response()`` + - ``failure_message_from_response()`` + + Subclasses should set ``target_states`` and ``failed_states`` fields. 
+ + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + """ + + ui_color = '#66c3ff' + + def __init__(self, *, aws_conn_id: str = 'aws_default', **kwargs): + super().__init__(**kwargs) + self.aws_conn_id = aws_conn_id + self.target_states: Optional[Iterable[str]] = None # will be set in subclasses + self.failed_states: Optional[Iterable[str]] = None # will be set in subclasses + self.hook: Optional[EmrHook] = None + + def get_hook(self) -> EmrHook: + """Get EmrHook""" + if self.hook: + return self.hook + + self.hook = EmrHook(aws_conn_id=self.aws_conn_id) + return self.hook + + def poke(self, context): + response = self.get_emr_response() + + if not response['ResponseMetadata']['HTTPStatusCode'] == 200: + self.log.info('Bad HTTP response: %s', response) + return False + + state = self.state_from_response(response) + self.log.info('Job flow currently %s', state) + + if state in self.target_states: + return True + + if state in self.failed_states: + final_message = 'EMR job failed' + failure_message = self.failure_message_from_response(response) + if failure_message: + final_message += ' ' + failure_message + raise AirflowException(final_message) + + return False + + def get_emr_response(self) -> Dict[str, Any]: + """ + Make an API call with boto3 and get response. + + :return: response + :rtype: dict[str, Any] + """ + raise NotImplementedError('Please implement get_emr_response() in subclass') + + @staticmethod + def state_from_response(response: Dict[str, Any]) -> str: + """ + Get state from response dictionary. + + :param response: response from AWS API + :type response: dict[str, Any] + :return: state + :rtype: str + """ + raise NotImplementedError('Please implement state_from_response() in subclass') + + @staticmethod + def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]: + """ + Get failure message from response dictionary. 
+ + :param response: response from AWS API + :type response: dict[str, Any] + :return: failure message + :rtype: Optional[str] + """ + raise NotImplementedError('Please implement failure_message_from_response() in subclass') + + +class EMRContainerSensor(BaseSensorOperator): Review comment: Can we use this opportunity to also rename the class to `EmrContainerSensor` for consistency? It will save us double deprecation in the future. ########## File path: airflow/providers/amazon/aws/operators/emr.py ########## @@ -0,0 +1,393 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import ast +from datetime import datetime +from typing import Any, Dict, List, Optional, Union +from uuid import uuid4 + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator, BaseOperatorLink, TaskInstance +from airflow.providers.amazon.aws.hooks.emr import EmrHook + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.providers.amazon.aws.hooks.emr_containers import EMRContainerHook + + +class EmrAddStepsOperator(BaseOperator): + """ + An operator that adds steps to an existing EMR job_flow. 
+ + :param job_flow_id: id of the JobFlow to add steps to. (templated) + :type job_flow_id: Optional[str] + :param job_flow_name: name of the JobFlow to add steps to. Use as an alternative to passing + job_flow_id. will search for id of JobFlow with matching name in one of the states in + param cluster_states. Exactly one cluster like this should exist or will fail. (templated) + :type job_flow_name: Optional[str] + :param cluster_states: Acceptable cluster states when searching for JobFlow id by job_flow_name. + (templated) + :type cluster_states: list + :param aws_conn_id: aws connection to uses + :type aws_conn_id: str + :param steps: boto3 style steps or reference to a steps file (must be '.json') to + be added to the jobflow. (templated) + :type steps: list|str + :param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id. + :type do_xcom_push: bool + """ + + template_fields = ['job_flow_id', 'job_flow_name', 'cluster_states', 'steps'] + template_ext = ('.json',) + template_fields_renderers = {"steps": "json"} + ui_color = '#f9c915' + + def __init__( + self, + *, + job_flow_id: Optional[str] = None, + job_flow_name: Optional[str] = None, + cluster_states: Optional[List[str]] = None, + aws_conn_id: str = 'aws_default', + steps: Optional[Union[List[dict], str]] = None, + **kwargs, + ): + if kwargs.get('xcom_push') is not None: + raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead") + if not (job_flow_id is None) ^ (job_flow_name is None): + raise AirflowException('Exactly one of job_flow_id or job_flow_name must be specified.') + super().__init__(**kwargs) + cluster_states = cluster_states or [] + steps = steps or [] + self.aws_conn_id = aws_conn_id + self.job_flow_id = job_flow_id + self.job_flow_name = job_flow_name + self.cluster_states = cluster_states + self.steps = steps + + def execute(self, context: Dict[str, Any]) -> List[str]: + emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) + + emr = 
emr_hook.get_conn() + + job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name( + str(self.job_flow_name), self.cluster_states + ) + + if not job_flow_id: + raise AirflowException(f'No cluster found for name: {self.job_flow_name}') + + if self.do_xcom_push: + context['ti'].xcom_push(key='job_flow_id', value=job_flow_id) + + self.log.info('Adding steps to %s', job_flow_id) + + # steps may arrive as a string representing a list + # e.g. if we used XCom or a file then: steps="[{ step1 }, { step2 }]" + steps = self.steps + if isinstance(steps, str): + steps = ast.literal_eval(steps) + + response = emr.add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps) + + if not response['ResponseMetadata']['HTTPStatusCode'] == 200: + raise AirflowException(f'Adding steps failed: {response}') + else: + self.log.info('Steps %s added to JobFlow', response['StepIds']) + return response['StepIds'] + + +class EMRContainerOperator(BaseOperator): Review comment: same comment [i.e. rename this class to `EmrContainerOperator` for consistency, to avoid a second deprecation later] -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
