Repository: incubator-airflow Updated Branches: refs/heads/master 02292c5d9 -> 68679ae52
[AIRFLOW-2151] Allow getting the session from AwsHook Add new tests for get_credentials and get_session methods of AwsHook Implement get_credentials and get_session methods in AwsHook Closes #3079 from inytar/aws-session Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/68679ae5 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/68679ae5 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/68679ae5 Branch: refs/heads/master Commit: 68679ae5263b8c01ea348624ae1c74a8e57a6c8d Parents: 02292c5 Author: inytar <pietp...@fastmail.net> Authored: Wed Feb 28 16:17:12 2018 +0100 Committer: Fokko Driesprong <fokkodriespr...@godatadriven.com> Committed: Wed Feb 28 16:17:35 2018 +0100 ---------------------------------------------------------------------- airflow/contrib/hooks/aws_hook.py | 16 ++++++ tests/contrib/hooks/test_aws_hook.py | 84 ++++++++++++++++++++++++++++++- 2 files changed, 99 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/68679ae5/airflow/contrib/hooks/aws_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py index 7e1b024..2a8fa5f 100644 --- a/airflow/contrib/hooks/aws_hook.py +++ b/airflow/contrib/hooks/aws_hook.py @@ -151,3 +151,19 @@ class AwsHook(BaseHook): session, endpoint_url = self._get_credentials(region_name) return session.resource(resource_type, endpoint_url=endpoint_url) + + def get_session(self, region_name=None): + """Get the underlying boto3.session.""" + session, _ = self._get_credentials(region_name) + return session + + def get_credentials(self, region_name=None): + """Get the underlying `botocore.Credentials` object. + + This contains the attributes: access_key, secret_key and token. + """ + session, _ = self._get_credentials(region_name) + # Credentials are refreshable, so accessing your access key / secret key + # separately can lead to a race condition. + # See https://stackoverflow.com/a/36291428/8283373 + return session.get_credentials().get_frozen_credentials() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/68679ae5/tests/contrib/hooks/test_aws_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_aws_hook.py b/tests/contrib/hooks/test_aws_hook.py index aa246f0..086e486 100644 --- a/tests/contrib/hooks/test_aws_hook.py +++ b/tests/contrib/hooks/test_aws_hook.py @@ -17,14 +17,23 @@ import unittest import boto3 from airflow import configuration +from airflow.models import Connection from airflow.contrib.hooks.aws_hook import AwsHook +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None try: - from moto import mock_emr, mock_dynamodb2 + from moto import mock_emr, mock_dynamodb2, mock_sts except ImportError: mock_emr = None mock_dynamodb2 = None + mock_sts = None class TestAwsHook(unittest.TestCase): @@ -77,6 +86,79 @@ class TestAwsHook(unittest.TestCase): self.assertEqual(table.item_count, 0) + @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present') + @mock_dynamodb2 + def test_get_session_returns_a_boto3_session(self): + hook = AwsHook(aws_conn_id='aws_default') + session_from_hook = hook.get_session() + resource_from_session = session_from_hook.resource('dynamodb') + table = resource_from_session.create_table( + TableName='test_airflow', + KeySchema=[ + { + 'AttributeName': 'id', + 'KeyType': 'HASH' + }, + ], + AttributeDefinitions=[ + { + 'AttributeName': 'name', + 'AttributeType': 'S' + } + ], + ProvisionedThroughput={ + 'ReadCapacityUnits': 10, + 'WriteCapacityUnits': 10 + } + ) + + table.meta.client.get_waiter( + 'table_exists').wait(TableName='test_airflow') + + self.assertEqual(table.item_count, 0) + + @mock.patch.object(AwsHook, 'get_connection') + def test_get_credentials_from_login(self, mock_get_connection): + mock_connection = Connection(login='aws_access_key_id', + password='aws_secret_access_key') + mock_get_connection.return_value = mock_connection + hook = AwsHook() + credentials_from_hook = hook.get_credentials() + self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id') + self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key') + self.assertIsNone(credentials_from_hook.token) + + @mock.patch.object(AwsHook, 'get_connection') + def test_get_credentials_from_extra(self, mock_get_connection): + mock_connection = Connection( + extra='{"aws_access_key_id": "aws_access_key_id",' + '"aws_secret_access_key": "aws_secret_access_key"}' + ) + mock_get_connection.return_value = mock_connection + hook = AwsHook() + credentials_from_hook = hook.get_credentials() + self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id') + self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key') + self.assertIsNone(credentials_from_hook.token) + + @unittest.skipIf(mock_sts is None, 'mock_sts package not present') + @mock.patch.object(AwsHook, 'get_connection') + @mock_sts + def test_get_credentials_from_role_arn(self, mock_get_connection): + mock_connection = Connection( + extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}') + mock_get_connection.return_value = mock_connection + hook = AwsHook() + credentials_from_hook = hook.get_credentials() + self.assertEqual(credentials_from_hook.access_key, 'AKIAIOSFODNN7EXAMPLE') + self.assertEqual(credentials_from_hook.secret_key, + 'aJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY') + self.assertEqual(credentials_from_hook.token, + 'BQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh' + '3c/LTo6UDdyJwOOvEVPvLXCrrrUtdnniCEXAMPLE/IvU1dYUg2RVAJBanLiHb4I' + 'gRmpRV3zrkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15' + 'fjrBs2+cTQtpZ3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE') + if __name__ == '__main__': unittest.main()