Repository: incubator-airflow
Updated Branches:
  refs/heads/master 1bf541180 -> 82a65eeca


[AIRFLOW-1968][AIRFLOW-1520] Add role_arn and aws_account_id/aws_iam_role 
support back to aws hook

In PR #2532 (AIRFLOW-1520), the AWS credential code was refactored into a
general AWS hook. When that change was made, the existing assume-role code
was removed, leaving only access key ID/secret credentials as an option.
Our DAGs rely on role assumption to access external S3 buckets, so this
change re-adds role assumption via STS.

Additionally, to make this a bit easier, I changed _get_credentials to
return a functioning boto3 session (together with the endpoint URL), which
the public methods use to initialize clients, resources, and so on. This
seemed a better route than adding yet another return value to an already
long list.

Closes #2918 from CannibalVox/aws_hook_support_sts


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/82a65eec
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/82a65eec
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/82a65eec

Branch: refs/heads/master
Commit: 82a65eeca9bc654cec4d0f356c3db70bc8ab6838
Parents: 1bf5411
Author: Stephen Baynham <[email protected]>
Authored: Mon Feb 5 14:15:23 2018 +0100
Committer: Bolke de Bruin <[email protected]>
Committed: Mon Feb 5 14:15:23 2018 +0100

----------------------------------------------------------------------
 airflow/contrib/hooks/aws_hook.py               | 52 +++++++++++++-------
 .../operators/test_emr_add_steps_operator.py    |  7 ++-
 .../test_emr_create_job_flow_operator.py        |  7 ++-
 .../test_emr_terminate_job_flow_operator.py     |  7 ++-
 .../contrib/sensors/test_emr_job_flow_sensor.py |  7 ++-
 tests/contrib/sensors/test_emr_step_sensor.py   | 12 +++--
 6 files changed, 61 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/82a65eec/airflow/contrib/hooks/aws_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/aws_hook.py 
b/airflow/contrib/hooks/aws_hook.py
index fac0ab7..7e1b024 100644
--- a/airflow/contrib/hooks/aws_hook.py
+++ b/airflow/contrib/hooks/aws_hook.py
@@ -84,6 +84,7 @@ class AwsHook(BaseHook):
     def _get_credentials(self, region_name):
         aws_access_key_id = None
         aws_secret_access_key = None
+        aws_session_token = None
         endpoint_url = None
 
         if self.aws_conn_id:
@@ -105,6 +106,29 @@ class AwsHook(BaseHook):
                 if region_name is None:
                     region_name = 
connection_object.extra_dejson.get('region_name')
 
+                role_arn = connection_object.extra_dejson.get('role_arn')
+                aws_account_id = 
connection_object.extra_dejson.get('aws_account_id')
+                aws_iam_role = 
connection_object.extra_dejson.get('aws_iam_role')
+
+                if role_arn is None and aws_account_id is not None and \
+                        aws_iam_role is not None:
+
+                    role_arn = "arn:aws:iam::" + aws_account_id + ":role/" + 
aws_iam_role
+
+                if role_arn is not None:
+                    sts_session = boto3.session.Session(
+                        aws_access_key_id=aws_access_key_id,
+                        aws_secret_access_key=aws_secret_access_key,
+                        region_name=region_name)
+
+                    sts_client = sts_session.client('sts')
+                    sts_response = sts_client.assume_role(
+                        RoleArn=role_arn,
+                        RoleSessionName='Airflow_' + self.aws_conn_id)
+                    aws_access_key_id = 
sts_response['Credentials']['AccessKeyId']
+                    aws_secret_access_key = 
sts_response['Credentials']['SecretAccessKey']
+                    aws_session_token = 
sts_response['Credentials']['SessionToken']
+
                 endpoint_url = connection_object.extra_dejson.get('host')
 
             except AirflowException:
@@ -112,28 +136,18 @@ class AwsHook(BaseHook):
                 # 
http://boto3.readthedocs.io/en/latest/guide/configuration.html
                 pass
 
-        return aws_access_key_id, aws_secret_access_key, region_name, 
endpoint_url
+        return boto3.session.Session(
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            aws_session_token=aws_session_token,
+            region_name=region_name), endpoint_url
 
     def get_client_type(self, client_type, region_name=None):
-        aws_access_key_id, aws_secret_access_key, region_name, endpoint_url = \
-            self._get_credentials(region_name)
+        session, endpoint_url = self._get_credentials(region_name)
 
-        return boto3.client(
-            client_type,
-            region_name=region_name,
-            aws_access_key_id=aws_access_key_id,
-            aws_secret_access_key=aws_secret_access_key,
-            endpoint_url=endpoint_url
-        )
+        return session.client(client_type, endpoint_url=endpoint_url)
 
     def get_resource_type(self, resource_type, region_name=None):
-        aws_access_key_id, aws_secret_access_key, region_name, endpoint_url = \
-            self._get_credentials(region_name)
+        session, endpoint_url = self._get_credentials(region_name)
 
-        return boto3.resource(
-            resource_type,
-            region_name=region_name,
-            aws_access_key_id=aws_access_key_id,
-            aws_secret_access_key=aws_secret_access_key,
-            endpoint_url=endpoint_url
-        )
+        return session.resource(resource_type, endpoint_url=endpoint_url)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/82a65eec/tests/contrib/operators/test_emr_add_steps_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_emr_add_steps_operator.py 
b/tests/contrib/operators/test_emr_add_steps_operator.py
index 37f9a4c..141e986 100644
--- a/tests/contrib/operators/test_emr_add_steps_operator.py
+++ b/tests/contrib/operators/test_emr_add_steps_operator.py
@@ -34,12 +34,15 @@ class TestEmrAddStepsOperator(unittest.TestCase):
         mock_emr_client = MagicMock()
         mock_emr_client.add_job_flow_steps.return_value = 
ADD_STEPS_SUCCESS_RETURN
 
+        mock_emr_session = MagicMock()
+        mock_emr_session.client.return_value = mock_emr_client
+
         # Mock out the emr_client creator
-        self.boto3_client_mock = MagicMock(return_value=mock_emr_client)
+        self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
 
 
     def test_execute_adds_steps_to_the_job_flow_and_returns_step_ids(self):
-        with patch('boto3.client', self.boto3_client_mock):
+        with patch('boto3.session.Session', self.boto3_session_mock):
 
             operator = EmrAddStepsOperator(
                 task_id='test_task',

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/82a65eec/tests/contrib/operators/test_emr_create_job_flow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_emr_create_job_flow_operator.py 
b/tests/contrib/operators/test_emr_create_job_flow_operator.py
index 4aa4cd2..9120aea 100644
--- a/tests/contrib/operators/test_emr_create_job_flow_operator.py
+++ b/tests/contrib/operators/test_emr_create_job_flow_operator.py
@@ -34,12 +34,15 @@ class TestEmrCreateJobFlowOperator(unittest.TestCase):
         mock_emr_client = MagicMock()
         mock_emr_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
 
+        mock_emr_session = MagicMock()
+        mock_emr_session.client.return_value = mock_emr_client
+
         # Mock out the emr_client creator
-        self.boto3_client_mock = MagicMock(return_value=mock_emr_client)
+        self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
 
 
     def 
test_execute_uses_the_emr_config_to_create_a_cluster_and_returns_job_id(self):
-        with patch('boto3.client', self.boto3_client_mock):
+        with patch('boto3.session.Session', self.boto3_session_mock):
 
             operator = EmrCreateJobFlowOperator(
                 task_id='test_task',

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/82a65eec/tests/contrib/operators/test_emr_terminate_job_flow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_emr_terminate_job_flow_operator.py 
b/tests/contrib/operators/test_emr_terminate_job_flow_operator.py
index 94c0124..2ecbb1c 100644
--- a/tests/contrib/operators/test_emr_terminate_job_flow_operator.py
+++ b/tests/contrib/operators/test_emr_terminate_job_flow_operator.py
@@ -33,12 +33,15 @@ class TestEmrTerminateJobFlowOperator(unittest.TestCase):
         mock_emr_client = MagicMock()
         mock_emr_client.terminate_job_flows.return_value = 
TERMINATE_SUCCESS_RETURN
 
+        mock_emr_session = MagicMock()
+        mock_emr_session.client.return_value = mock_emr_client
+
         # Mock out the emr_client creator
-        self.boto3_client_mock = MagicMock(return_value=mock_emr_client)
+        self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
 
 
     def test_execute_terminates_the_job_flow_and_does_not_error(self):
-        with patch('boto3.client', self.boto3_client_mock):
+        with patch('boto3.session.Session', self.boto3_session_mock):
 
             operator = EmrTerminateJobFlowOperator(
                 task_id='test_task',

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/82a65eec/tests/contrib/sensors/test_emr_job_flow_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_emr_job_flow_sensor.py 
b/tests/contrib/sensors/test_emr_job_flow_sensor.py
index f993786..b73c7ea 100644
--- a/tests/contrib/sensors/test_emr_job_flow_sensor.py
+++ b/tests/contrib/sensors/test_emr_job_flow_sensor.py
@@ -96,12 +96,15 @@ class TestEmrJobFlowSensor(unittest.TestCase):
             DESCRIBE_CLUSTER_TERMINATED_RETURN
         ]
 
+        mock_emr_session = MagicMock()
+        mock_emr_session.client.return_value = self.mock_emr_client
+
         # Mock out the emr_client creator
-        self.boto3_client_mock = MagicMock(return_value=self.mock_emr_client)
+        self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
 
 
     def 
test_execute_calls_with_the_job_flow_id_until_it_reaches_a_terminal_state(self):
-        with patch('boto3.client', self.boto3_client_mock):
+        with patch('boto3.session.Session', self.boto3_session_mock):
 
             operator = EmrJobFlowSensor(
                 task_id='test_task',

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/82a65eec/tests/contrib/sensors/test_emr_step_sensor.py
----------------------------------------------------------------------
diff --git a/tests/contrib/sensors/test_emr_step_sensor.py 
b/tests/contrib/sensors/test_emr_step_sensor.py
index b5d43fb..192acd5 100644
--- a/tests/contrib/sensors/test_emr_step_sensor.py
+++ b/tests/contrib/sensors/test_emr_step_sensor.py
@@ -121,15 +121,19 @@ class TestEmrStepSensor(unittest.TestCase):
             aws_conn_id='aws_default',
         )
 
+        mock_emr_session = MagicMock()
+        mock_emr_session.client.return_value = self.emr_client_mock
+
+        # Mock out the emr_client creator
+        self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
+
     def test_step_completed(self):
         self.emr_client_mock.describe_step.side_effect = [
             DESCRIBE_JOB_STEP_RUNNING_RETURN,
             DESCRIBE_JOB_STEP_COMPLETED_RETURN
         ]
 
-        self.boto3_client_mock = MagicMock(return_value=self.emr_client_mock)
-
-        with patch('boto3.client', self.boto3_client_mock):
+        with patch('boto3.session.Session', self.boto3_session_mock):
             self.sensor.execute(None)
 
             self.assertEqual(self.emr_client_mock.describe_step.call_count, 2)
@@ -146,7 +150,7 @@ class TestEmrStepSensor(unittest.TestCase):
 
         self.boto3_client_mock = MagicMock(return_value=self.emr_client_mock)
 
-        with patch('boto3.client', self.boto3_client_mock):
+        with patch('boto3.session.Session', self.boto3_session_mock):
             self.assertRaises(AirflowException, self.sensor.execute, None)
 
 

Reply via email to