Repository: incubator-airflow
Updated Branches:
  refs/heads/master 02292c5d9 -> 68679ae52


[AIRFLOW-2151] Allow getting the session from AwsHook

Add new tests for the get_credentials and get_session
methods of AwsHook. Implement the get_credentials and
get_session methods in AwsHook.

Closes #3079 from inytar/aws-session


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/68679ae5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/68679ae5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/68679ae5

Branch: refs/heads/master
Commit: 68679ae5263b8c01ea348624ae1c74a8e57a6c8d
Parents: 02292c5
Author: inytar <pietp...@fastmail.net>
Authored: Wed Feb 28 16:17:12 2018 +0100
Committer: Fokko Driesprong <fokkodriespr...@godatadriven.com>
Committed: Wed Feb 28 16:17:35 2018 +0100

----------------------------------------------------------------------
 airflow/contrib/hooks/aws_hook.py    | 16 ++++++
 tests/contrib/hooks/test_aws_hook.py | 84 ++++++++++++++++++++++++++++++-
 2 files changed, 99 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/68679ae5/airflow/contrib/hooks/aws_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py
index 7e1b024..2a8fa5f 100644
--- a/airflow/contrib/hooks/aws_hook.py
+++ b/airflow/contrib/hooks/aws_hook.py
@@ -151,3 +151,19 @@ class AwsHook(BaseHook):
         session, endpoint_url = self._get_credentials(region_name)
 
         return session.resource(resource_type, endpoint_url=endpoint_url)
+
+    def get_session(self, region_name=None):
+        """Get the underlying boto3.session."""
+        session, _ = self._get_credentials(region_name)
+        return session
+
+    def get_credentials(self, region_name=None):
+        """Get the underlying `botocore.Credentials` object.
+
+        This contains the attributes: access_key, secret_key and token.
+        """
+        session, _ = self._get_credentials(region_name)
+        # Credentials are refreshable, so accessing your access key / secret 
key
+        # separately can lead to a race condition.
+        # See https://stackoverflow.com/a/36291428/8283373
+        return session.get_credentials().get_frozen_credentials()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/68679ae5/tests/contrib/hooks/test_aws_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_aws_hook.py b/tests/contrib/hooks/test_aws_hook.py
index aa246f0..086e486 100644
--- a/tests/contrib/hooks/test_aws_hook.py
+++ b/tests/contrib/hooks/test_aws_hook.py
@@ -17,14 +17,23 @@ import unittest
 import boto3
 
 from airflow import configuration
+from airflow.models import Connection
 from airflow.contrib.hooks.aws_hook import AwsHook
 
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
 
 try:
-    from moto import mock_emr, mock_dynamodb2
+    from moto import mock_emr, mock_dynamodb2, mock_sts
 except ImportError:
     mock_emr = None
     mock_dynamodb2 = None
+    mock_sts = None
 
 
 class TestAwsHook(unittest.TestCase):
@@ -77,6 +86,79 @@ class TestAwsHook(unittest.TestCase):
 
         self.assertEqual(table.item_count, 0)
 
+    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not 
present')
+    @mock_dynamodb2
+    def test_get_session_returns_a_boto3_session(self):
+        hook = AwsHook(aws_conn_id='aws_default')
+        session_from_hook = hook.get_session()
+        resource_from_session = session_from_hook.resource('dynamodb')
+        table = resource_from_session.create_table(
+            TableName='test_airflow',
+            KeySchema=[
+                {
+                    'AttributeName': 'id',
+                    'KeyType': 'HASH'
+                },
+            ],
+            AttributeDefinitions=[
+                {
+                    'AttributeName': 'name',
+                    'AttributeType': 'S'
+                }
+            ],
+            ProvisionedThroughput={
+                'ReadCapacityUnits': 10,
+                'WriteCapacityUnits': 10
+            }
+        )
+
+        table.meta.client.get_waiter(
+            'table_exists').wait(TableName='test_airflow')
+
+        self.assertEqual(table.item_count, 0)
+
+    @mock.patch.object(AwsHook, 'get_connection')
+    def test_get_credentials_from_login(self, mock_get_connection):
+        mock_connection = Connection(login='aws_access_key_id',
+                                     password='aws_secret_access_key')
+        mock_get_connection.return_value = mock_connection
+        hook = AwsHook()
+        credentials_from_hook = hook.get_credentials()
+        self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
+        self.assertEqual(credentials_from_hook.secret_key, 
'aws_secret_access_key')
+        self.assertIsNone(credentials_from_hook.token)
+
+    @mock.patch.object(AwsHook, 'get_connection')
+    def test_get_credentials_from_extra(self, mock_get_connection):
+        mock_connection = Connection(
+            extra='{"aws_access_key_id": "aws_access_key_id",'
+            '"aws_secret_access_key": "aws_secret_access_key"}'
+        )
+        mock_get_connection.return_value = mock_connection
+        hook = AwsHook()
+        credentials_from_hook = hook.get_credentials()
+        self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
+        self.assertEqual(credentials_from_hook.secret_key, 
'aws_secret_access_key')
+        self.assertIsNone(credentials_from_hook.token)
+
+    @unittest.skipIf(mock_sts is None, 'mock_sts package not present')
+    @mock.patch.object(AwsHook, 'get_connection')
+    @mock_sts
+    def test_get_credentials_from_role_arn(self, mock_get_connection):
+        mock_connection = Connection(
+            extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}')
+        mock_get_connection.return_value = mock_connection
+        hook = AwsHook()
+        credentials_from_hook = hook.get_credentials()
+        self.assertEqual(credentials_from_hook.access_key, 
'AKIAIOSFODNN7EXAMPLE')
+        self.assertEqual(credentials_from_hook.secret_key,
+                         'aJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY')
+        self.assertEqual(credentials_from_hook.token,
+                         
'BQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh'
+                         
'3c/LTo6UDdyJwOOvEVPvLXCrrrUtdnniCEXAMPLE/IvU1dYUg2RVAJBanLiHb4I'
+                         
'gRmpRV3zrkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15'
+                         'fjrBs2+cTQtpZ3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE')
+
 
 if __name__ == '__main__':
     unittest.main()

Reply via email to