This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch v1-10-stable in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 66382d8d68aa2c675367f007cc013613df54542c Author: Ahmad Maruf <[email protected]> AuthorDate: Fri Jun 12 15:03:51 2020 -0700 Bug fix for EmrAddStepOperator init with cluster_name error (#9235) Closes #9127 --- airflow/contrib/operators/emr_add_steps_operator.py | 8 +++++--- tests/contrib/operators/test_emr_add_steps_operator.py | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/airflow/contrib/operators/emr_add_steps_operator.py b/airflow/contrib/operators/emr_add_steps_operator.py index 0075b1b..1917752 100644 --- a/airflow/contrib/operators/emr_add_steps_operator.py +++ b/airflow/contrib/operators/emr_add_steps_operator.py @@ -66,12 +66,14 @@ class EmrAddStepsOperator(BaseOperator): self.steps = steps def execute(self, context): - emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn() + emr_hook = EmrHook(aws_conn_id=self.aws_conn_id) - job_flow_id = self.job_flow_id + emr = emr_hook.get_conn() + job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(self.job_flow_name, + self.cluster_states) if not job_flow_id: - job_flow_id = emr.get_cluster_id_by_name(self.job_flow_name, self.cluster_states) + raise AirflowException('No cluster found for name: ' + self.job_flow_name) if self.do_xcom_push: context['ti'].xcom_push(key='job_flow_id', value=job_flow_id) diff --git a/tests/contrib/operators/test_emr_add_steps_operator.py b/tests/contrib/operators/test_emr_add_steps_operator.py index e0cadfe..97ebc42 100644 --- a/tests/contrib/operators/test_emr_add_steps_operator.py +++ b/tests/contrib/operators/test_emr_add_steps_operator.py @@ -105,10 +105,11 @@ class TestEmrAddStepsOperator(unittest.TestCase): with patch('boto3.session.Session', self.boto3_session_mock): self.assertEqual(self.operator.execute(self.mock_context), ['s-2LH3R5GW3A53T']) + @patch.multiple('airflow.contrib.hooks.emr_hook.EmrHook', + get_cluster_id_by_name=MagicMock(return_value='j-1231231234')) def test_init_with_cluster_name(self): expected_job_flow_id = 'j-1231231234' - self.emr_client_mock.get_cluster_id_by_name.return_value = expected_job_flow_id self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN with patch('boto3.session.Session', self.boto3_session_mock):
