This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 867e930 AwsBaseHook make `client_type` & `resource_type` optional
params for `get_client_type` & `get_resource_type` (#17987)
867e930 is described below
commit 867e9305f08bf9580f25430d8b6e84071c59f9e6
Author: eladkal <[email protected]>
AuthorDate: Fri Sep 3 22:29:58 2021 +0300
AwsBaseHook make `client_type` & `resource_type` optional params for
`get_client_type` & `get_resource_type` (#17987)
* AwsBaseHook make client_type & resource_type optional params for
get_client_type & get_resource_type
---
airflow/providers/amazon/aws/hooks/base_aws.py | 23 ++++++-
tests/providers/amazon/aws/hooks/test_base_aws.py | 80 ++++++++++++++++++++++-
2 files changed, 100 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py
b/airflow/providers/amazon/aws/hooks/base_aws.py
index 57010ed..e936557 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -27,6 +27,7 @@ This module contains Base AWS Hook.
import configparser
import datetime
import logging
+import warnings
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, Union
@@ -433,13 +434,22 @@ class AwsBaseHook(BaseHook):
def get_client_type(
self,
- client_type: str,
+ client_type: Optional[str] = None,
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session"""
session, endpoint_url = self._get_credentials(region_name)
+ if client_type:
+ warnings.warn(
+ "client_type is deprecated. Set client_type from class
attribute.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ else:
+ client_type = self.client_type
+
# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
@@ -449,13 +459,22 @@ class AwsBaseHook(BaseHook):
def get_resource_type(
self,
- resource_type: str,
+ resource_type: Optional[str] = None,
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session"""
session, endpoint_url = self._get_credentials(region_name)
+ if resource_type:
+ warnings.warn(
+ "resource_type is deprecated. Set resource_type from class
attribute.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ else:
+ resource_type = self.resource_type
+
# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py
b/tests/providers/amazon/aws/hooks/test_base_aws.py
index c934387..0000136 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -111,12 +111,41 @@ class TestAwsBaseHook(unittest.TestCase):
client = boto3.client('emr', region_name='us-east-1')
if client.list_clusters()['Clusters']:
raise ValueError('AWS not properly mocked')
-
hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
client_from_hook = hook.get_client_type('emr')
assert client_from_hook.list_clusters()['Clusters'] == []
+ @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
+ @mock_emr
+ def test_get_client_type_set_in_class_attribute(self):
+ client = boto3.client('emr', region_name='us-east-1')
+ if client.list_clusters()['Clusters']:
+ raise ValueError('AWS not properly mocked')
+ hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
+ client_from_hook = hook.get_client_type()
+
+ assert client_from_hook.list_clusters()['Clusters'] == []
+
+ @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
+ @mock_emr
+ def test_get_client_type_overwrite(self):
+ client = boto3.client('emr', region_name='us-east-1')
+ if client.list_clusters()['Clusters']:
+ raise ValueError('AWS not properly mocked')
+ hook = AwsBaseHook(aws_conn_id='aws_default', client_type='dynamodb')
+ client_from_hook = hook.get_client_type(client_type='emr')
+ assert client_from_hook.list_clusters()['Clusters'] == []
+
+ @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
+ @mock_emr
+ def test_get_client_type_deprecation_warning(self):
+ hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
+ warning_message = """client_type is deprecated. Set client_type from
class attribute."""
+ with pytest.warns(DeprecationWarning) as warnings:
+ hook.get_client_type(client_type='emr')
+ assert warning_message == str(warnings[0].message)
+
@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not
present')
@mock_dynamodb2
def
test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self):
@@ -139,6 +168,55 @@ class TestAwsBaseHook(unittest.TestCase):
@unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not
present')
@mock_dynamodb2
+ def test_get_resource_type_set_in_class_attribute(self):
+ hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
+ resource_from_hook = hook.get_resource_type()
+
+ # this table needs to be created in production
+ table = resource_from_hook.create_table(
+ TableName='test_airflow',
+ KeySchema=[
+ {'AttributeName': 'id', 'KeyType': 'HASH'},
+ ],
+ AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType':
'S'}],
+ ProvisionedThroughput={'ReadCapacityUnits': 10,
'WriteCapacityUnits': 10},
+ )
+
+
table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')
+
+ assert table.item_count == 0
+
+ @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not
present')
+ @mock_dynamodb2
+ def test_get_resource_type_overwrite(self):
+ hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='s3')
+ resource_from_hook = hook.get_resource_type('dynamodb')
+
+ # this table needs to be created in production
+ table = resource_from_hook.create_table(
+ TableName='test_airflow',
+ KeySchema=[
+ {'AttributeName': 'id', 'KeyType': 'HASH'},
+ ],
+ AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType':
'S'}],
+ ProvisionedThroughput={'ReadCapacityUnits': 10,
'WriteCapacityUnits': 10},
+ )
+
+
table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')
+
+ assert table.item_count == 0
+
+ @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not
present')
+ @mock_dynamodb2
+ def test_get_resource_deprecation_warning(self):
+ hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
+ warning_message = """resource_type is deprecated. Set resource_type
from class attribute."""
+ with pytest.warns(DeprecationWarning) as warnings:
+ hook.get_resource_type('dynamodb')
+ assert warning_message == str(warnings[0].message)
+
+ @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not
present')
+ @mock_dynamodb2
def test_get_session_returns_a_boto3_session(self):
hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
session_from_hook = hook.get_session()