This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4eb0a410bb Use only public AwsHook's methods during IAM authorization
(#25424)
4eb0a410bb is described below
commit 4eb0a410bb2a9c3d195da0ce4e129c401ad25174
Author: Andrey Anshin <[email protected]>
AuthorDate: Tue Aug 2 21:38:51 2022 +0400
Use only public AwsHook's methods during IAM authorization (#25424)
---
airflow/providers/postgres/hooks/postgres.py | 37 +++---
airflow/providers/postgres/provider.yaml | 6 +-
.../connections/postgres.rst | 34 +++++
tests/providers/postgres/hooks/test_postgres.py | 143 +++++++++++++++------
4 files changed, 163 insertions(+), 57 deletions(-)
diff --git a/airflow/providers/postgres/hooks/postgres.py
b/airflow/providers/postgres/hooks/postgres.py
index 09c07c9b8f..747458df2f 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -174,29 +174,27 @@ class PostgresHook(DbApiHook):
or Redshift. Port is required. If none is provided, default is used for
each service
"""
- from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+ try:
+ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+ except ImportError:
+ from airflow.exceptions import AirflowException
+
+ raise AirflowException(
+ "apache-airflow-providers-amazon not installed, run: "
+ "pip install 'apache-airflow-providers-postgres[amazon]'."
+ )
- redshift = conn.extra_dejson.get('redshift', False)
aws_conn_id = conn.extra_dejson.get('aws_conn_id', 'aws_default')
- aws_hook = AwsBaseHook(aws_conn_id, client_type='rds')
login = conn.login
- if conn.port is None:
- port = 5439 if redshift else 5432
- else:
- port = conn.port
- if redshift:
+ if conn.extra_dejson.get('redshift', False):
+ port = conn.port or 5439
# Pull the cluster-identifier from the beginning of the Redshift URL
# ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com
returns my-cluster
cluster_identifier = conn.extra_dejson.get('cluster-identifier',
conn.host.split('.')[0])
- session, endpoint_url = aws_hook._get_credentials(region_name=None)
- client = session.client(
- "redshift",
- endpoint_url=endpoint_url,
- config=aws_hook.config,
- verify=aws_hook.verify,
- )
- cluster_creds = client.get_cluster_credentials(
- DbUser=conn.login,
+ redshift_client = AwsBaseHook(aws_conn_id=aws_conn_id,
client_type="redshift").conn
+ #
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
+ cluster_creds = redshift_client.get_cluster_credentials(
+ DbUser=login,
DbName=self.schema or conn.schema,
ClusterIdentifier=cluster_identifier,
AutoCreate=False,
@@ -204,7 +202,10 @@ class PostgresHook(DbApiHook):
token = cluster_creds['DbPassword']
login = cluster_creds['DbUser']
else:
- token = aws_hook.conn.generate_db_auth_token(conn.host, port,
conn.login)
+ port = conn.port or 5432
+ rds_client = AwsBaseHook(aws_conn_id=aws_conn_id,
client_type="rds").conn
+ #
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.generate_db_auth_token
+ token = rds_client.generate_db_auth_token(conn.host, port,
conn.login)
return login, token, port
def get_table_primary_key(self, table: str, schema: Optional[str] =
"public") -> Optional[List[str]]:
diff --git a/airflow/providers/postgres/provider.yaml
b/airflow/providers/postgres/provider.yaml
index 2a240ac1c9..618d75a645 100644
--- a/airflow/providers/postgres/provider.yaml
+++ b/airflow/providers/postgres/provider.yaml
@@ -60,7 +60,11 @@ hooks:
python-modules:
- airflow.providers.postgres.hooks.postgres
-
connection-types:
- hook-class-name: airflow.providers.postgres.hooks.postgres.PostgresHook
connection-type: postgres
+
+additional-extras:
+ - name: amazon
+ dependencies:
+ - apache-airflow-providers-amazon>=2.6.0
diff --git a/docs/apache-airflow-providers-postgres/connections/postgres.rst
b/docs/apache-airflow-providers-postgres/connections/postgres.rst
index 3ffbae7d27..f97e99af84 100644
--- a/docs/apache-airflow-providers-postgres/connections/postgres.rst
+++ b/docs/apache-airflow-providers-postgres/connections/postgres.rst
@@ -74,6 +74,40 @@ Extra (optional)
"sslkey": "/tmp/client-key.pem"
}
+ The following extra parameters are used for additional Hook configuration:
+
+ * ``iam`` - If set to ``True`` then use AWS IAM database authentication for
+ `Amazon RDS
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html>`__,
+ `Amazon Aurora
<https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.html>`__
+ or `Amazon Redshift
<https://docs.aws.amazon.com/redshift/latest/mgmt/generating-user-credentials.html>`__.
+ * ``aws_conn_id`` - AWS Connection ID which is used for authentication via AWS IAM,
+ if not specified then **aws_default** is used.
+ * ``redshift`` - Used when AWS IAM database authentication is enabled.
+ If set to ``True`` then authenticate to Amazon Redshift Cluster,
otherwise to Amazon RDS or Amazon Aurora.
+ * ``cluster-identifier`` - The unique identifier of the Amazon Redshift
Cluster that contains the database
+ for which you are requesting credentials. This parameter is case
sensitive.
+ If not specified then the hostname from **Connection Host** is used.
+
+ Example "extras" field (Amazon RDS PostgreSQL or Amazon Aurora PostgreSQL):
+
+ .. code-block:: json
+
+ {
+ "iam": true,
+ "aws_conn_id": "aws_awesome_rds_conn"
+ }
+
+ Example "extras" field (Amazon Redshift):
+
+ .. code-block:: json
+
+ {
+ "iam": true,
+ "aws_conn_id": "aws_awesome_redshift_conn",
+ "redshift": "/tmp/server-ca.pem",
+ "cluster-identifier": "awesome-redshift-identifier"
+ }
+
When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}`
variable) you should specify it
following the standard syntax of DB connections, where extras are passed
as parameters
of the URI (note that all components of the URI should be URL-encoded).
diff --git a/tests/providers/postgres/hooks/test_postgres.py
b/tests/providers/postgres/hooks/test_postgres.py
index 847260cb88..ab91b47f0a 100644
--- a/tests/providers/postgres/hooks/test_postgres.py
+++ b/tests/providers/postgres/hooks/test_postgres.py
@@ -26,12 +26,12 @@ import pytest
from airflow.models import Connection
from airflow.providers.postgres.hooks.postgres import PostgresHook
+from airflow.utils.types import NOTSET
-class TestPostgresHookConn(unittest.TestCase):
- def setUp(self):
- super().setUp()
-
+class TestPostgresHookConn:
+ @pytest.fixture(autouse=True)
+ def setup(self):
self.connection = Connection(login='login', password='password',
host='host', schema='schema')
class UnitTestPostgresHook(PostgresHook):
@@ -63,10 +63,7 @@ class TestPostgresHookConn(unittest.TestCase):
self.connection.conn_type = 'postgres'
self.db_hook.get_conn()
assert mock_connect.call_count == 1
-
- self.assertEqual(
- self.db_hook.get_uri(),
"postgresql://login:password@host/schema?client_encoding=utf-8"
- )
+ assert self.db_hook.get_uri() ==
"postgresql://login:password@host/schema?client_encoding=utf-8"
@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
def test_get_conn_cursor(self, mock_connect):
@@ -106,13 +103,41 @@ class TestPostgresHookConn(unittest.TestCase):
)
@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
-
@mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_client_type')
- def test_get_conn_rds_iam_postgres(self, mock_client, mock_connect):
- self.connection.extra = '{"iam":true}'
- mock_client.return_value.generate_db_auth_token.return_value =
'aws_token'
+ @mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook')
+ @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+ @pytest.mark.parametrize("port", [65432, 5432, None])
+ def test_get_conn_rds_iam_postgres(self, mock_aws_hook_class,
mock_connect, aws_conn_id, port):
+ mock_conn_extra = {"iam": True}
+ if aws_conn_id is not NOTSET:
+ mock_conn_extra["aws_conn_id"] = aws_conn_id
+ self.connection.extra = json.dumps(mock_conn_extra)
+ self.connection.port = port
+ mock_db_token = "aws_token"
+
+ # Mock AWS Connection
+ mock_aws_hook_instance = mock_aws_hook_class.return_value
+ mock_client = mock.MagicMock()
+ mock_client.generate_db_auth_token.return_value = mock_db_token
+ type(mock_aws_hook_instance).conn =
mock.PropertyMock(return_value=mock_client)
+
self.db_hook.get_conn()
+ # Check AwsHook initialization
+ mock_aws_hook_class.assert_called_once_with(
+ # If aws_conn_id not set then fall back to aws_default
+ aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else
"aws_default",
+ client_type="rds",
+ )
+ # Check boto3 'rds' client method `generate_db_auth_token` call args
+ mock_client.generate_db_auth_token.assert_called_once_with(
+ self.connection.host, (port or 5432), self.connection.login
+ )
+ # Check expected psycopg2 connection call args
mock_connect.assert_called_once_with(
- user='login', password='aws_token', host='host', dbname='schema',
port=5432
+ user=self.connection.login,
+ password=mock_db_token,
+ host=self.connection.host,
+ dbname=self.connection.schema,
+ port=(port or 5432),
)
@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
@@ -124,39 +149,81 @@ class TestPostgresHookConn(unittest.TestCase):
)
@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
- def test_get_conn_rds_iam_redshift(self, mock_connect):
- self.connection.extra = '{"iam":true, "redshift":true,
"cluster-identifier": "different-identifier"}'
- self.connection.host =
'cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com'
- login = f'IAM:{self.connection.login}'
-
- mock_session = mock.Mock()
- mock_get_cluster_credentials =
mock_session.client.return_value.get_cluster_credentials
- mock_get_cluster_credentials.return_value = {'DbPassword':
'aws_token', 'DbUser': login}
-
- aws_get_credentials_patcher = mock.patch(
-
"airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook._get_credentials",
- return_value=(mock_session, None),
+ @mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook')
+ @pytest.mark.parametrize("aws_conn_id", [NOTSET, None, "mock_aws_conn"])
+ @pytest.mark.parametrize("port", [5432, 5439, None])
+ @pytest.mark.parametrize(
+ "host,conn_cluster_identifier,expected_cluster_identifier",
+ [
+ (
+
'cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com',
+ NOTSET,
+ 'cluster-identifier',
+ ),
+ (
+
'cluster-identifier.ccdfre4hpd39h.us-east-1.redshift.amazonaws.com',
+ 'different-identifier',
+ 'different-identifier',
+ ),
+ ],
+ )
+ def test_get_conn_rds_iam_redshift(
+ self,
+ mock_aws_hook_class,
+ mock_connect,
+ aws_conn_id,
+ port,
+ host,
+ conn_cluster_identifier,
+ expected_cluster_identifier,
+ ):
+ mock_conn_extra = {
+ "iam": True,
+ "redshift": True,
+ }
+ if aws_conn_id is not NOTSET:
+ mock_conn_extra["aws_conn_id"] = aws_conn_id
+ if conn_cluster_identifier is not NOTSET:
+ mock_conn_extra["cluster-identifier"] = conn_cluster_identifier
+
+ self.connection.extra = json.dumps(mock_conn_extra)
+ self.connection.host = host
+ self.connection.port = port
+ mock_db_user = f'IAM:{self.connection.login}'
+ mock_db_pass = "aws_token"
+
+ # Mock AWS Connection
+ mock_aws_hook_instance = mock_aws_hook_class.return_value
+ mock_client = mock.MagicMock()
+ mock_client.get_cluster_credentials.return_value = {
+ 'DbPassword': mock_db_pass,
+ 'DbUser': mock_db_user,
+ }
+ type(mock_aws_hook_instance).conn =
mock.PropertyMock(return_value=mock_client)
+
+ self.db_hook.get_conn()
+ # Check AwsHook initialization
+ mock_aws_hook_class.assert_called_once_with(
+ # If aws_conn_id not set then fall back to aws_default
+ aws_conn_id=aws_conn_id if aws_conn_id is not NOTSET else
"aws_default",
+ client_type="redshift",
)
- get_cluster_credentials_call = mock.call(
+ # Check boto3 'redshift' client method `get_cluster_credentials` call
args
+ mock_client.get_cluster_credentials.assert_called_once_with(
DbUser=self.connection.login,
DbName=self.connection.schema,
- ClusterIdentifier="different-identifier",
+ ClusterIdentifier=expected_cluster_identifier,
AutoCreate=False,
)
-
- with aws_get_credentials_patcher:
- self.db_hook.get_conn()
- assert mock_get_cluster_credentials.mock_calls ==
[get_cluster_credentials_call]
+ # Check expected psycopg2 connection call args
mock_connect.assert_called_once_with(
- user=login, password='aws_token', host=self.connection.host,
dbname='schema', port=5439
+ user=mock_db_user,
+ password=mock_db_pass,
+ host=host,
+ dbname=self.connection.schema,
+ port=(port or 5439),
)
- # Verify that the connection object has not been mutated.
- mock_get_cluster_credentials.reset_mock()
- with aws_get_credentials_patcher:
- self.db_hook.get_conn()
- assert mock_get_cluster_credentials.mock_calls ==
[get_cluster_credentials_call]
-
def test_get_uri_from_connection_without_schema_override(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(