Fokko closed pull request #4011: [AIRFLOW-2216] Use profile for AWS hook if S3 
config file provided in aws_default connection extra parameters
URL: https://github.com/apache/incubator-airflow/pull/4011
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as GitHub would not show it otherwise):

diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py
index 448de63ffe..8ce74d2b4e 100644
--- a/airflow/contrib/hooks/aws_hook.py
+++ b/airflow/contrib/hooks/aws_hook.py
@@ -97,33 +97,36 @@ def _get_credentials(self, region_name):
         if self.aws_conn_id:
             try:
                 connection_object = self.get_connection(self.aws_conn_id)
+                extra_config = connection_object.extra_dejson
                 if connection_object.login:
                     aws_access_key_id = connection_object.login
                     aws_secret_access_key = connection_object.password
 
-                elif 'aws_secret_access_key' in connection_object.extra_dejson:
-                    aws_access_key_id = connection_object.extra_dejson[
+                elif 'aws_secret_access_key' in extra_config:
+                    aws_access_key_id = extra_config[
                         'aws_access_key_id']
-                    aws_secret_access_key = connection_object.extra_dejson[
+                    aws_secret_access_key = extra_config[
                         'aws_secret_access_key']
 
-                elif 's3_config_file' in connection_object.extra_dejson:
+                elif 's3_config_file' in extra_config:
                     aws_access_key_id, aws_secret_access_key = \
                         _parse_s3_config(
-                            connection_object.extra_dejson['s3_config_file'],
-                            connection_object.extra_dejson.get('s3_config_format'))
+                            extra_config['s3_config_file'],
+                            extra_config.get('s3_config_format'),
+                            extra_config.get('profile'))
 
                 if region_name is None:
-                    region_name = connection_object.extra_dejson.get('region_name')
+                    region_name = extra_config.get('region_name')
 
-                role_arn = connection_object.extra_dejson.get('role_arn')
-                external_id = connection_object.extra_dejson.get('external_id')
-                aws_account_id = connection_object.extra_dejson.get('aws_account_id')
-                aws_iam_role = connection_object.extra_dejson.get('aws_iam_role')
+                role_arn = extra_config.get('role_arn')
+                external_id = extra_config.get('external_id')
+                aws_account_id = extra_config.get('aws_account_id')
+                aws_iam_role = extra_config.get('aws_iam_role')
 
                 if role_arn is None and aws_account_id is not None and \
                         aws_iam_role is not None:
-                    role_arn = "arn:aws:iam::" + aws_account_id + ":role/" + aws_iam_role
+                    role_arn = "arn:aws:iam::{}:role/{}" \
+                        .format(aws_account_id, aws_iam_role)
 
                 if role_arn is not None:
                     sts_session = boto3.session.Session(
@@ -143,11 +146,12 @@ def _get_credentials(self, region_name):
                             RoleSessionName='Airflow_' + self.aws_conn_id,
                             ExternalId=external_id)
 
-                    aws_access_key_id = sts_response['Credentials']['AccessKeyId']
-                    aws_secret_access_key = sts_response['Credentials']['SecretAccessKey']
-                    aws_session_token = sts_response['Credentials']['SessionToken']
+                    credentials = sts_response['Credentials']
+                    aws_access_key_id = credentials['AccessKeyId']
+                    aws_secret_access_key = credentials['SecretAccessKey']
+                    aws_session_token = credentials['SessionToken']
 
-                endpoint_url = connection_object.extra_dejson.get('host')
+                endpoint_url = extra_config.get('host')
 
             except AirflowException:
                 # No connection found: fallback on boto3 credential strategy
@@ -183,7 +187,7 @@ def get_credentials(self, region_name=None):
         This contains the attributes: access_key, secret_key and token.
         """
         session, _ = self._get_credentials(region_name)
-        # Credentials are refreshable, so accessing your access key / secret key
-        # separately can lead to a race condition.
+        # Credentials are refreshable, so accessing your access key and
+        # secret key separately can lead to a race condition.
         # See https://stackoverflow.com/a/36291428/8283373
         return session.get_credentials().get_frozen_credentials()
diff --git a/tests/contrib/hooks/test_aws_hook.py b/tests/contrib/hooks/test_aws_hook.py
index d7664aca11..eaadc5fbff 100644
--- a/tests/contrib/hooks/test_aws_hook.py
+++ b/tests/contrib/hooks/test_aws_hook.py
@@ -19,6 +19,7 @@
 #
 
 import unittest
+
 import boto3
 
 from airflow import configuration
@@ -146,6 +147,26 @@ def test_get_credentials_from_extra(self, mock_get_connection):
         self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
         self.assertIsNone(credentials_from_hook.token)
 
+    @mock.patch('airflow.contrib.hooks.aws_hook._parse_s3_config',
+                return_value=('aws_access_key_id', 'aws_secret_access_key'))
+    @mock.patch.object(AwsHook, 'get_connection')
+    def test_get_credentials_from_extra_with_s3_config_and_profile(
+        self, mock_get_connection, mock_parse_s3_config
+    ):
+        mock_connection = Connection(
+            extra='{"s3_config_format": "aws", '
+                  '"profile": "test", '
+                  '"s3_config_file": "aws-credentials", '
+                  '"region_name": "us-east-1"}')
+        mock_get_connection.return_value = mock_connection
+        hook = AwsHook()
+        hook._get_credentials(region_name=None)
+        mock_parse_s3_config.assert_called_with(
+            'aws-credentials',
+            'aws',
+            'test'
+        )
+
     @unittest.skipIf(mock_sts is None, 'mock_sts package not present')
     @mock.patch.object(AwsHook, 'get_connection')
     @mock_sts


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to