This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4eb0a410bb Use only public AwsHook's methods during IAM authorization 
(#25424)
4eb0a410bb is described below

commit 4eb0a410bb2a9c3d195da0ce4e129c401ad25174
Author: Andrey Anshin <[email protected]>
AuthorDate: Tue Aug 2 21:38:51 2022 +0400

    Use only public AwsHook's methods during IAM authorization (#25424)
---
 airflow/providers/postgres/hooks/postgres.py       |  37 +++---
 airflow/providers/postgres/provider.yaml           |   6 +-
 .../connections/postgres.rst                       |  34 +++++
 tests/providers/postgres/hooks/test_postgres.py    | 143 +++++++++++++++------
 4 files changed, 163 insertions(+), 57 deletions(-)

diff --git a/airflow/providers/postgres/hooks/postgres.py 
b/airflow/providers/postgres/hooks/postgres.py
index 09c07c9b8f..747458df2f 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -174,29 +174,27 @@ class PostgresHook(DbApiHook):
         or Redshift. Port is required. If none is provided, default is used for
         each service
         """
-        from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+        try:
+            from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+        except ImportError:
+            from airflow.exceptions import AirflowException
+
+            raise AirflowException(
+                "apache-airflow-providers-amazon not installed, run: "
+                "pip install 'apache-airflow-providers-postgres[amazon]'."
+            )
 
-        redshift = conn.extra_dejson.get('redshift', False)
         aws_conn_id = conn.extra_dejson.get('aws_conn_id', 'aws_default')
-        aws_hook = AwsBaseHook(aws_conn_id, client_type='rds')
         login = conn.login
-        if conn.port is None:
-            port = 5439 if redshift else 5432
-        else:
-            port = conn.port
-        if redshift:
+        if conn.extra_dejson.get('redshift', False):
+            port = conn.port or 5439
             # Pull the custer-identifier from the beginning of the Redshift URL
             # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com 
returns my-cluster
             cluster_identifier = conn.extra_dejson.get('cluster-identifier', 
conn.host.split('.')[0])
-            session, endpoint_url = aws_hook._get_credentials(region_name=None)
-            client = session.client(
-                "redshift",
-                endpoint_url=endpoint_url,
-                config=aws_hook.config,
-                verify=aws_hook.verify,
-            )
-            cluster_creds = client.get_cluster_credentials(
-                DbUser=conn.login,
+            redshift_client = AwsBaseHook(aws_conn_id=aws_conn_id, 
client_type="redshift").conn
+            # 
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
+            cluster_creds = redshift_client.get_cluster_credentials(
+                DbUser=login,
                 DbName=self.schema or conn.schema,
                 ClusterIdentifier=cluster_identifier,
                 AutoCreate=False,
@@ -204,7 +202,10 @@ class PostgresHook(DbApiHook):
             token = cluster_creds['DbPassword']
             login = cluster_creds['DbUser']
         else:
-            token = aws_hook.conn.generate_db_auth_token(conn.host, port, 
conn.login)
+            port = conn.port or 5432
+            rds_client = AwsBaseHook(aws_conn_id=aws_conn_id, 
client_type="rds").conn
+            # 
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.generate_db_auth_token
+            token = rds_client.generate_db_auth_token(conn.host, port, 
conn.login)
         return login, token, port
 
     def get_table_primary_key(self, table: str, schema: Optional[str] = 
"public") -> Optional[List[str]]:
diff --git a/airflow/providers/postgres/provider.yaml 
b/airflow/providers/postgres/provider.yaml
index 2a240ac1c9..618d75a645 100644
--- a/airflow/providers/postgres/provider.yaml
+++ b/airflow/providers/postgres/provider.yaml
@@ -60,7 +60,11 @@ hooks:
     python-modules:
       - airflow.providers.postgres.hooks.postgres
 
-
 connection-types:
   - hook-class-name: airflow.providers.postgres.hooks.postgres.PostgresHook
     connection-type: postgres
+
+additional-extras:
+  - name: amazon
+    dependencies:
+      - apache-airflow-providers-amazon>=2.6.0
diff --git a/docs/apache-airflow-providers-postgres/connections/postgres.rst 
b/docs/apache-airflow-providers-postgres/connections/postgres.rst
index 3ffbae7d27..f97e99af84 100644
--- a/docs/apache-airflow-providers-postgres/connections/postgres.rst
+++ b/docs/apache-airflow-providers-postgres/connections/postgres.rst
@@ -74,6 +74,40 @@ Extra (optional)
           "sslkey": "/tmp/client-key.pem"
        }
 
+    The following extra parameters are used for additional Hook configuration:
+
+    * ``iam`` - If set to ``True`` then use AWS IAM database authentication for
+      `Amazon RDS 
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html>`__,
+      `Amazon Aurora 
<https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.html>`__
+      or `Amazon Redshift 
<https://docs.aws.amazon.com/redshift/latest/mgmt/generating-user-credentials.html>`__.
+    * ``aws_conn_id`` - AWS Connection ID which is used for authentication via AWS 
IAM,
+      if not specified then **aws_default** is used.
+    * ``redshift`` - Used when AWS IAM database authentication is enabled.
+      If set to ``True`` then authenticate to Amazon Redshift Cluster, 
otherwise to Amazon RDS or Amazon Aurora.
+    * ``cluster-identifier`` - The unique identifier of the Amazon Redshift 
Cluster that contains the database
+      for which you are requesting credentials. This parameter is case 
sensitive.
+      If not specified then the first component of the hostname from **Connection Host** is used.
+
+    Example "extras" field (Amazon RDS PostgreSQL or Amazon Aurora PostgreSQL):
+
+    .. code-block:: json
+
+       {
+          "iam": true,
+          "aws_conn_id": "aws_awesome_rds_conn"
+       }
+
+    Example "extras" field (Amazon Redshift):
+
+    .. code-block:: json
+
+       {
+          "iam": true,
+          "aws_conn_id": "aws_awesome_redshift_conn",
+          "redshift": true,
+          "cluster-identifier": "awesome-redshift-identifier"
+       }
+
     When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}` 
variable) you should specify it
     following the standard syntax of DB connections, where extras are passed 
as parameters
     of the URI (note that all components of the URI should be URL-encoded).
diff --git a/tests/providers/postgres/hooks/test_postgres.py 
b/tests/providers/postgres/hooks/test_postgres.py
index 847260cb88..ab91b47f0a 100644
--- a/tests/providers/postgres/hooks/test_postgres.py
+++ b/tests/providers/postgres/hooks/test_postgres.py
@@ -26,12 +26,12 @@ import pytest
 
 from airflow.models import Connection
 from airflow.providers.postgres.hooks.postgres import PostgresHook
+from airflow.utils.types import NOTSET
 
 
-class TestPostgresHookConn(unittest.TestCase):
-    def setUp(self):
-        super().setUp()
-
+class TestPostgresHookConn:
+    @pytest.fixture(autouse=True)
+    def setup(self):
         self.connection = Connection(login='login', password='password', 
host='host', schema='schema')
 
         class UnitTestPostgresHook(PostgresHook):
@@ -63,10 +63,7 @@ class TestPostgresHookConn(unittest.TestCase):
         self.connection.conn_type = 'postgres'
         self.db_hook.get_conn()
         assert mock_connect.call_count == 1
-
-        self.assertEqual(
-            self.db_hook.get_uri(), 
"postgresql://login:password@host/schema?client_encoding=utf-8"
-        )
+        assert self.db_hook.get_uri() == 
"postgresql://login:password@host/schema?client_encoding=utf-8"
 
     @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
     def test_get_conn_cursor(self, mock_connect):
@@ -106,13 +103,41 @@ class TestPostgresHookConn(unittest.TestCase):
         )
 
     @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
-    
@mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_client_type')
-    def test_get_conn_rds_iam_postgres(self, mock_client, mock_connect):
-        self.connection.extra = '{"iam":true}'
-        mock_client.return_value.generate_db_auth_token.return_value = 
'aws_token'
+    @mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook')
+    @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+    @pytest.mark.parametrize("port", [65432, 5432, None])
+    def test_get_conn_rds_iam_postgres(self, mock_aws_hook_class, 
mock_connect, aws_conn_id, port):
+        mock_conn_extra = {"iam": True}
+        if aws_conn_id is not NOTSET:
+            mock_conn_extra["aws_conn_id"] = aws_conn_id
+        self.connection.extra = json.dumps(mock_conn_extra)
+        self.connection.port = port
+        mock_db_token = "aws_token"
+
+        # Mock AWS Connection
+        mock_aws_hook_instance = mock_aws_hook_class.return_value
+        mock_client = mock.MagicMock()
+        mock_client.generate_db_auth_token.return_value = mock_db_token
+        type(mock_aws_hook_instance).conn = 
mock.PropertyMock(return_value=mock_client)
+
         self.db_hook.get_conn()
+        # Check AwsHook initialization
+        mock_aws_hook_class.assert_called_once_with(
+            # If aws_conn_id not set than fallback to aws_default
+            aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else 
"aws_default",
+            client_type="rds",
+        )
+        # Check boto3 'rds' client method `generate_db_auth_token` call args
+        mock_client.generate_db_auth_token.assert_called_once_with(
+            self.connection.host, (port or 5432), self.connection.login
+        )
+        # Check expected psycopg2 connection call args
         mock_connect.assert_called_once_with(
-            user='login', password='aws_token', host='host', dbname='schema', 
port=5432
+            user=self.connection.login,
+            password=mock_db_token,
+            host=self.connection.host,
+            dbname=self.connection.schema,
+            port=(port or 5432),
         )
 
     @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
@@ -124,39 +149,81 @@ class TestPostgresHookConn(unittest.TestCase):
         )
 
     @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
-    def test_get_conn_rds_iam_redshift(self, mock_connect):
-        self.connection.extra = '{"iam":true, "redshift":true, 
"cluster-identifier": "different-identifier"}'
-        self.connection.host = 
'cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com'
-        login = f'IAM:{self.connection.login}'
-
-        mock_session = mock.Mock()
-        mock_get_cluster_credentials = 
mock_session.client.return_value.get_cluster_credentials
-        mock_get_cluster_credentials.return_value = {'DbPassword': 
'aws_token', 'DbUser': login}
-
-        aws_get_credentials_patcher = mock.patch(
-            
"airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook._get_credentials",
-            return_value=(mock_session, None),
+    @mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook')
+    @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+    @pytest.mark.parametrize("port", [5432, 5439, None])
+    @pytest.mark.parametrize(
+        "host,conn_cluster_identifier,expected_cluster_identifier",
+        [
+            (
+                
'cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com',
+                NOTSET,
+                'cluster-identifier',
+            ),
+            (
+                
'cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com',
+                'different-identifier',
+                'different-identifier',
+            ),
+        ],
+    )
+    def test_get_conn_rds_iam_redshift(
+        self,
+        mock_aws_hook_class,
+        mock_connect,
+        aws_conn_id,
+        port,
+        host,
+        conn_cluster_identifier,
+        expected_cluster_identifier,
+    ):
+        mock_conn_extra = {
+            "iam": True,
+            "redshift": True,
+        }
+        if aws_conn_id is not NOTSET:
+            mock_conn_extra["aws_conn_id"] = aws_conn_id
+        if conn_cluster_identifier is not NOTSET:
+            mock_conn_extra["cluster-identifier"] = conn_cluster_identifier
+
+        self.connection.extra = json.dumps(mock_conn_extra)
+        self.connection.host = host
+        self.connection.port = port
+        mock_db_user = f'IAM:{self.connection.login}'
+        mock_db_pass = "aws_token"
+
+        # Mock AWS Connection
+        mock_aws_hook_instance = mock_aws_hook_class.return_value
+        mock_client = mock.MagicMock()
+        mock_client.get_cluster_credentials.return_value = {
+            'DbPassword': mock_db_pass,
+            'DbUser': mock_db_user,
+        }
+        type(mock_aws_hook_instance).conn = 
mock.PropertyMock(return_value=mock_client)
+
+        self.db_hook.get_conn()
+        # Check AwsHook initialization
+        mock_aws_hook_class.assert_called_once_with(
+            # If aws_conn_id not set than fallback to aws_default
+            aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else 
"aws_default",
+            client_type="redshift",
         )
-        get_cluster_credentials_call = mock.call(
+        # Check boto3 'redshift' client method `get_cluster_credentials` call 
args
+        mock_client.get_cluster_credentials.assert_called_once_with(
             DbUser=self.connection.login,
             DbName=self.connection.schema,
-            ClusterIdentifier="different-identifier",
+            ClusterIdentifier=expected_cluster_identifier,
             AutoCreate=False,
         )
-
-        with aws_get_credentials_patcher:
-            self.db_hook.get_conn()
-        assert mock_get_cluster_credentials.mock_calls == 
[get_cluster_credentials_call]
+        # Check expected psycopg2 connection call args
         mock_connect.assert_called_once_with(
-            user=login, password='aws_token', host=self.connection.host, 
dbname='schema', port=5439
+            user=mock_db_user,
+            password=mock_db_pass,
+            host=host,
+            dbname=self.connection.schema,
+            port=(port or 5439),
         )
 
-        # Verify that the connection object has not been mutated.
-        mock_get_cluster_credentials.reset_mock()
-        with aws_get_credentials_patcher:
-            self.db_hook.get_conn()
-        assert mock_get_cluster_credentials.mock_calls == 
[get_cluster_credentials_call]
-
     def test_get_uri_from_connection_without_schema_override(self):
         self.db_hook.get_connection = mock.MagicMock(
             return_value=Connection(

Reply via email to