This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 867e930  AwsBaseHook make `client_type` & `resource_type` optional params for `get_client_type` & `get_resource_type` (#17987)
867e930 is described below

commit 867e9305f08bf9580f25430d8b6e84071c59f9e6
Author: eladkal <[email protected]>
AuthorDate: Fri Sep 3 22:29:58 2021 +0300

    AwsBaseHook make `client_type` & `resource_type` optional params for `get_client_type` & `get_resource_type` (#17987)
    
    * AwsBaseHook make client_type & resource_type optional params for get_client_type & get_resource_type
---
 airflow/providers/amazon/aws/hooks/base_aws.py    | 23 ++++++-
 tests/providers/amazon/aws/hooks/test_base_aws.py | 80 ++++++++++++++++++++++-
 2 files changed, 100 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py 
b/airflow/providers/amazon/aws/hooks/base_aws.py
index 57010ed..e936557 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -27,6 +27,7 @@ This module contains Base AWS Hook.
 import configparser
 import datetime
 import logging
+import warnings
 from functools import wraps
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
@@ -433,13 +434,22 @@ class AwsBaseHook(BaseHook):
 
     def get_client_type(
         self,
-        client_type: str,
+        client_type: Optional[str] = None,
         region_name: Optional[str] = None,
         config: Optional[Config] = None,
     ) -> boto3.client:
         """Get the underlying boto3 client using boto3 session"""
         session, endpoint_url = self._get_credentials(region_name)
 
+        if client_type:
+            warnings.warn(
+                "client_type is deprecated. Set client_type from class 
attribute.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        else:
+            client_type = self.client_type
+
         # No AWS Operators use the config argument to this method.
         # Keep backward compatibility with other users who might use it
         if config is None:
@@ -449,13 +459,22 @@ class AwsBaseHook(BaseHook):
 
     def get_resource_type(
         self,
-        resource_type: str,
+        resource_type: Optional[str] = None,
         region_name: Optional[str] = None,
         config: Optional[Config] = None,
     ) -> boto3.resource:
         """Get the underlying boto3 resource using boto3 session"""
         session, endpoint_url = self._get_credentials(region_name)
 
+        if resource_type:
+            warnings.warn(
+                "resource_type is deprecated. Set resource_type from class 
attribute.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        else:
+            resource_type = self.resource_type
+
         # No AWS Operators use the config argument to this method.
         # Keep backward compatibility with other users who might use it
         if config is None:
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py 
b/tests/providers/amazon/aws/hooks/test_base_aws.py
index c934387..0000136 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -111,12 +111,41 @@ class TestAwsBaseHook(unittest.TestCase):
         client = boto3.client('emr', region_name='us-east-1')
         if client.list_clusters()['Clusters']:
             raise ValueError('AWS not properly mocked')
-
         hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
         client_from_hook = hook.get_client_type('emr')
 
         assert client_from_hook.list_clusters()['Clusters'] == []
 
+    @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
+    @mock_emr
+    def test_get_client_type_set_in_class_attribute(self):
+        client = boto3.client('emr', region_name='us-east-1')
+        if client.list_clusters()['Clusters']:
+            raise ValueError('AWS not properly mocked')
+        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
+        client_from_hook = hook.get_client_type()
+
+        assert client_from_hook.list_clusters()['Clusters'] == []
+
+    @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
+    @mock_emr
+    def test_get_client_type_overwrite(self):
+        client = boto3.client('emr', region_name='us-east-1')
+        if client.list_clusters()['Clusters']:
+            raise ValueError('AWS not properly mocked')
+        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='dynamodb')
+        client_from_hook = hook.get_client_type(client_type='emr')
+        assert client_from_hook.list_clusters()['Clusters'] == []
+
+    @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
+    @mock_emr
+    def test_get_client_type_deprecation_warning(self):
+        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
+        warning_message = """client_type is deprecated. Set client_type from 
class attribute."""
+        with pytest.warns(DeprecationWarning) as warnings:
+            hook.get_client_type(client_type='emr')
+            assert warning_message == str(warnings[0].message)
+
     @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not 
present')
     @mock_dynamodb2
     def 
test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self):
@@ -139,6 +168,55 @@ class TestAwsBaseHook(unittest.TestCase):
 
     @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not 
present')
     @mock_dynamodb2
+    def test_get_resource_type_set_in_class_attribute(self):
+        hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
+        resource_from_hook = hook.get_resource_type()
+
+        # this table needs to be created in production
+        table = resource_from_hook.create_table(
+            TableName='test_airflow',
+            KeySchema=[
+                {'AttributeName': 'id', 'KeyType': 'HASH'},
+            ],
+            AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 
'S'}],
+            ProvisionedThroughput={'ReadCapacityUnits': 10, 
'WriteCapacityUnits': 10},
+        )
+
+        
table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')
+
+        assert table.item_count == 0
+
+    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not 
present')
+    @mock_dynamodb2
+    def test_get_resource_type_overwrite(self):
+        hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='s3')
+        resource_from_hook = hook.get_resource_type('dynamodb')
+
+        # this table needs to be created in production
+        table = resource_from_hook.create_table(
+            TableName='test_airflow',
+            KeySchema=[
+                {'AttributeName': 'id', 'KeyType': 'HASH'},
+            ],
+            AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 
'S'}],
+            ProvisionedThroughput={'ReadCapacityUnits': 10, 
'WriteCapacityUnits': 10},
+        )
+
+        
table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')
+
+        assert table.item_count == 0
+
+    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not 
present')
+    @mock_dynamodb2
+    def test_get_resource_deprecation_warning(self):
+        hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
+        warning_message = """resource_type is deprecated. Set resource_type 
from class attribute."""
+        with pytest.warns(DeprecationWarning) as warnings:
+            hook.get_resource_type('dynamodb')
+            assert warning_message == str(warnings[0].message)
+
+    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not 
present')
+    @mock_dynamodb2
     def test_get_session_returns_a_boto3_session(self):
         hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
         session_from_hook = hook.get_session()

Reply via email to