Repository: incubator-airflow Updated Branches: refs/heads/master 1bf541180 -> 82a65eeca
[AIRFLOW-1968][AIRFLOW-1520] Add role_arn and aws_account_id/aws_iam_role support back to the AWS hook. In PR #2532 (AIRFLOW-1520), the AWS credential code was refactored into a general AWS hook. When that change was made, the existing assume-role code was removed, leaving access key ID/secret key credentials as the only option. Our DAGs rely on role assumption to access external S3 buckets, so this change re-adds role assumption via STS. Additionally, to make this easier, I changed _get_credentials to return a ready-to-use boto3 session, which the public methods use to create clients, resources, and so on. This seemed a better route than adding another return value to an already long tuple. Closes #2918 from CannibalVox/aws_hook_support_sts Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/82a65eec Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/82a65eec Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/82a65eec Branch: refs/heads/master Commit: 82a65eeca9bc654cec4d0f356c3db70bc8ab6838 Parents: 1bf5411 Author: Stephen Baynham <[email protected]> Authored: Mon Feb 5 14:15:23 2018 +0100 Committer: Bolke de Bruin <[email protected]> Committed: Mon Feb 5 14:15:23 2018 +0100 ---------------------------------------------------------------------- airflow/contrib/hooks/aws_hook.py | 52 +++++++++++++------- .../operators/test_emr_add_steps_operator.py | 7 ++- .../test_emr_create_job_flow_operator.py | 7 ++- .../test_emr_terminate_job_flow_operator.py | 7 ++- .../contrib/sensors/test_emr_job_flow_sensor.py | 7 ++- tests/contrib/sensors/test_emr_step_sensor.py | 12 +++-- 6 files changed, 61 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/82a65eec/airflow/contrib/hooks/aws_hook.py 
---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py index fac0ab7..7e1b024 100644 --- a/airflow/contrib/hooks/aws_hook.py +++ b/airflow/contrib/hooks/aws_hook.py @@ -84,6 +84,7 @@ class AwsHook(BaseHook): def _get_credentials(self, region_name): aws_access_key_id = None aws_secret_access_key = None + aws_session_token = None endpoint_url = None if self.aws_conn_id: @@ -105,6 +106,29 @@ class AwsHook(BaseHook): if region_name is None: region_name = connection_object.extra_dejson.get('region_name') + role_arn = connection_object.extra_dejson.get('role_arn') + aws_account_id = connection_object.extra_dejson.get('aws_account_id') + aws_iam_role = connection_object.extra_dejson.get('aws_iam_role') + + if role_arn is None and aws_account_id is not None and \ + aws_iam_role is not None: + + role_arn = "arn:aws:iam::" + aws_account_id + ":role/" + aws_iam_role + + if role_arn is not None: + sts_session = boto3.session.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=region_name) + + sts_client = sts_session.client('sts') + sts_response = sts_client.assume_role( + RoleArn=role_arn, + RoleSessionName='Airflow_' + self.aws_conn_id) + aws_access_key_id = sts_response['Credentials']['AccessKeyId'] + aws_secret_access_key = sts_response['Credentials']['SecretAccessKey'] + aws_session_token = sts_response['Credentials']['SessionToken'] + endpoint_url = connection_object.extra_dejson.get('host') except AirflowException: @@ -112,28 +136,18 @@ class AwsHook(BaseHook): # http://boto3.readthedocs.io/en/latest/guide/configuration.html pass - return aws_access_key_id, aws_secret_access_key, region_name, endpoint_url + return boto3.session.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + region_name=region_name), endpoint_url def 
get_client_type(self, client_type, region_name=None): - aws_access_key_id, aws_secret_access_key, region_name, endpoint_url = \ - self._get_credentials(region_name) + session, endpoint_url = self._get_credentials(region_name) - return boto3.client( - client_type, - region_name=region_name, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - endpoint_url=endpoint_url - ) + return session.client(client_type, endpoint_url=endpoint_url) def get_resource_type(self, resource_type, region_name=None): - aws_access_key_id, aws_secret_access_key, region_name, endpoint_url = \ - self._get_credentials(region_name) + session, endpoint_url = self._get_credentials(region_name) - return boto3.resource( - resource_type, - region_name=region_name, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - endpoint_url=endpoint_url - ) + return session.resource(resource_type, endpoint_url=endpoint_url) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/82a65eec/tests/contrib/operators/test_emr_add_steps_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_emr_add_steps_operator.py b/tests/contrib/operators/test_emr_add_steps_operator.py index 37f9a4c..141e986 100644 --- a/tests/contrib/operators/test_emr_add_steps_operator.py +++ b/tests/contrib/operators/test_emr_add_steps_operator.py @@ -34,12 +34,15 @@ class TestEmrAddStepsOperator(unittest.TestCase): mock_emr_client = MagicMock() mock_emr_client.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN + mock_emr_session = MagicMock() + mock_emr_session.client.return_value = mock_emr_client + # Mock out the emr_client creator - self.boto3_client_mock = MagicMock(return_value=mock_emr_client) + self.boto3_session_mock = MagicMock(return_value=mock_emr_session) def test_execute_adds_steps_to_the_job_flow_and_returns_step_ids(self): - with patch('boto3.client', self.boto3_client_mock): 
+ with patch('boto3.session.Session', self.boto3_session_mock): operator = EmrAddStepsOperator( task_id='test_task', http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/82a65eec/tests/contrib/operators/test_emr_create_job_flow_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_emr_create_job_flow_operator.py b/tests/contrib/operators/test_emr_create_job_flow_operator.py index 4aa4cd2..9120aea 100644 --- a/tests/contrib/operators/test_emr_create_job_flow_operator.py +++ b/tests/contrib/operators/test_emr_create_job_flow_operator.py @@ -34,12 +34,15 @@ class TestEmrCreateJobFlowOperator(unittest.TestCase): mock_emr_client = MagicMock() mock_emr_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN + mock_emr_session = MagicMock() + mock_emr_session.client.return_value = mock_emr_client + # Mock out the emr_client creator - self.boto3_client_mock = MagicMock(return_value=mock_emr_client) + self.boto3_session_mock = MagicMock(return_value=mock_emr_session) def test_execute_uses_the_emr_config_to_create_a_cluster_and_returns_job_id(self): - with patch('boto3.client', self.boto3_client_mock): + with patch('boto3.session.Session', self.boto3_session_mock): operator = EmrCreateJobFlowOperator( task_id='test_task', http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/82a65eec/tests/contrib/operators/test_emr_terminate_job_flow_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_emr_terminate_job_flow_operator.py b/tests/contrib/operators/test_emr_terminate_job_flow_operator.py index 94c0124..2ecbb1c 100644 --- a/tests/contrib/operators/test_emr_terminate_job_flow_operator.py +++ b/tests/contrib/operators/test_emr_terminate_job_flow_operator.py @@ -33,12 +33,15 @@ class TestEmrTerminateJobFlowOperator(unittest.TestCase): mock_emr_client = MagicMock() mock_emr_client.terminate_job_flows.return_value = 
TERMINATE_SUCCESS_RETURN + mock_emr_session = MagicMock() + mock_emr_session.client.return_value = mock_emr_client + # Mock out the emr_client creator - self.boto3_client_mock = MagicMock(return_value=mock_emr_client) + self.boto3_session_mock = MagicMock(return_value=mock_emr_session) def test_execute_terminates_the_job_flow_and_does_not_error(self): - with patch('boto3.client', self.boto3_client_mock): + with patch('boto3.session.Session', self.boto3_session_mock): operator = EmrTerminateJobFlowOperator( task_id='test_task', http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/82a65eec/tests/contrib/sensors/test_emr_job_flow_sensor.py ---------------------------------------------------------------------- diff --git a/tests/contrib/sensors/test_emr_job_flow_sensor.py b/tests/contrib/sensors/test_emr_job_flow_sensor.py index f993786..b73c7ea 100644 --- a/tests/contrib/sensors/test_emr_job_flow_sensor.py +++ b/tests/contrib/sensors/test_emr_job_flow_sensor.py @@ -96,12 +96,15 @@ class TestEmrJobFlowSensor(unittest.TestCase): DESCRIBE_CLUSTER_TERMINATED_RETURN ] + mock_emr_session = MagicMock() + mock_emr_session.client.return_value = self.mock_emr_client + # Mock out the emr_client creator - self.boto3_client_mock = MagicMock(return_value=self.mock_emr_client) + self.boto3_session_mock = MagicMock(return_value=mock_emr_session) def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_terminal_state(self): - with patch('boto3.client', self.boto3_client_mock): + with patch('boto3.session.Session', self.boto3_session_mock): operator = EmrJobFlowSensor( task_id='test_task', http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/82a65eec/tests/contrib/sensors/test_emr_step_sensor.py ---------------------------------------------------------------------- diff --git a/tests/contrib/sensors/test_emr_step_sensor.py b/tests/contrib/sensors/test_emr_step_sensor.py index b5d43fb..192acd5 100644 --- a/tests/contrib/sensors/test_emr_step_sensor.py +++ 
b/tests/contrib/sensors/test_emr_step_sensor.py @@ -121,15 +121,19 @@ class TestEmrStepSensor(unittest.TestCase): aws_conn_id='aws_default', ) + mock_emr_session = MagicMock() + mock_emr_session.client.return_value = self.emr_client_mock + + # Mock out the emr_client creator + self.boto3_session_mock = MagicMock(return_value=mock_emr_session) + def test_step_completed(self): self.emr_client_mock.describe_step.side_effect = [ DESCRIBE_JOB_STEP_RUNNING_RETURN, DESCRIBE_JOB_STEP_COMPLETED_RETURN ] - self.boto3_client_mock = MagicMock(return_value=self.emr_client_mock) - - with patch('boto3.client', self.boto3_client_mock): + with patch('boto3.session.Session', self.boto3_session_mock): self.sensor.execute(None) self.assertEqual(self.emr_client_mock.describe_step.call_count, 2) @@ -146,7 +150,7 @@ class TestEmrStepSensor(unittest.TestCase): self.boto3_client_mock = MagicMock(return_value=self.emr_client_mock) - with patch('boto3.client', self.boto3_client_mock): + with patch('boto3.session.Session', self.boto3_session_mock): self.assertRaises(AirflowException, self.sensor.execute, None)
