This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0faffe40b4e Azure IAM/Entra ID support for PostgresHook (#55729)
0faffe40b4e is described below
commit 0faffe40b4e561aad3cfcfca4632f1bcf3ad2bca
Author: karunpoudel <[email protected]>
AuthorDate: Sat Oct 11 04:59:37 2025 -0400
Azure IAM/Entra ID support for PostgresHook (#55729)
* init
fx
fx
pre-commit
progress
fx
fx
fx
* fx
* add config
* conf doc
* fx
* fix1
* update version
* fix doc
* prek fix
* updates
* fix
* Update pyproject.toml
* Update pyproject.toml
* fix
* try
* pass
* doc and test
---------
Co-authored-by: Karun Poudel <[email protected]>
Co-authored-by: Karun Poudel <[email protected]>
---
dev/breeze/tests/test_selective_checks.py | 4 +-
providers/postgres/docs/configurations-ref.rst | 19 ++++++++
providers/postgres/docs/connections/postgres.rst | 15 +++++-
providers/postgres/docs/index.rst | 16 ++++---
providers/postgres/provider.yaml | 14 ++++++
providers/postgres/pyproject.toml | 4 ++
.../providers/postgres/get_provider_info.py | 14 ++++++
.../airflow/providers/postgres/hooks/postgres.py | 53 +++++++++++++++++++---
.../tests/unit/postgres/hooks/test_postgres.py | 51 +++++++++++++++++++++
9 files changed, 174 insertions(+), 16 deletions(-)
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index 5c00a026921..a972f9d11ec 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -645,7 +645,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
            ),
            {
                "selected-providers-list-as-string": "amazon common.sql google "
-                "openlineage pgvector postgres",
+                "microsoft.azure openlineage pgvector postgres",
                "all-python-versions": f"['{DEFAULT_PYTHON_MAJOR_MINOR_VERSION}']",
                "all-python-versions-list-as-string": DEFAULT_PYTHON_MAJOR_MINOR_VERSION,
                "python-versions": f"['{DEFAULT_PYTHON_MAJOR_MINOR_VERSION}']",
@@ -667,7 +667,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                {
                    "description": "amazon...google",
                    "test_types": "Providers[amazon] "
-                    "Providers[common.sql,openlineage,pgvector,postgres] "
+                    "Providers[common.sql,microsoft.azure,openlineage,pgvector,postgres] "
                    "Providers[google]",
                }
}
]
diff --git a/providers/postgres/docs/configurations-ref.rst b/providers/postgres/docs/configurations-ref.rst
new file mode 100644
index 00000000000..a52b21b2e56
--- /dev/null
+++ b/providers/postgres/docs/configurations-ref.rst
@@ -0,0 +1,19 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. include:: /../../../devel-common/src/sphinx_exts/includes/providers-configurations-ref.rst
+.. include:: /../../../devel-common/src/sphinx_exts/includes/sections-and-options.rst
diff --git a/providers/postgres/docs/connections/postgres.rst b/providers/postgres/docs/connections/postgres.rst
index 539620ad08c..3018769afbc 100644
--- a/providers/postgres/docs/connections/postgres.rst
+++ b/providers/postgres/docs/connections/postgres.rst
@@ -96,7 +96,9 @@ Extra (optional)
    * ``iam`` - If set to ``True`` than use AWS IAM database authentication for
      `Amazon RDS <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html>`__,
      `Amazon Aurora <https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.html>`__
-      or `Amazon Redshift <https://docs.aws.amazon.com/redshift/latest/mgmt/generating-user-credentials.html>`__.
+      `Amazon Redshift <https://docs.aws.amazon.com/redshift/latest/mgmt/generating-user-credentials.html>`__
+      or use Microsoft Entra Authentication for
+      `Azure Postgres Flexible Server <https://learn.microsoft.com/en-us/azure/postgresql/flexible-server/security-entra-concepts>`__.
    * ``aws_conn_id`` - AWS Connection ID which use for authentication via AWS IAM,
      if not specified then **aws_default** is used.
* ``redshift`` - Used when AWS IAM database authentication enabled.
@@ -104,6 +106,8 @@ Extra (optional)
      * ``cluster-identifier`` - The unique identifier of the Amazon Redshift Cluster that contains the database
        for which you are requesting credentials. This parameter is case sensitive.
        If not specified than hostname from **Connection Host** is used.
+    * ``azure_conn_id`` - Azure Connection ID to be used for authentication via Azure Entra ID. The Azure OAuth token
+      retrieved from the Azure connection is used as the password for the PostgreSQL connection. The scope for the Azure OAuth token can be set in the config option ``azure_oauth_scope`` under the section ``[postgres]``. Requires `apache-airflow-providers-microsoft-azure>=12.8.0`.
Example "extras" field (Amazon RDS PostgreSQL or Amazon Aurora PostgreSQL):
@@ -125,6 +129,15 @@ Extra (optional)
"cluster-identifier": "awesome-redshift-identifier"
}
+    Example "extras" field (to use Azure Entra Authentication for Postgres Flexible Server):
+
+ .. code-block:: json
+
+ {
+ "iam": true,
+ "azure_conn_id": "azure_default_conn"
+ }
+
    When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}` variable) you should specify it
    following the standard syntax of DB connections, where extras are passed as parameters
of the URI (note that all components of the URI should be URL-encoded).
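
    Note: as a minimal sketch of how such a connection might be defined and used from Python, assuming an
    illustrative connection id ``azure_pg``, illustrative server/database names, and an existing Azure
    connection ``azure_default`` (with apache-airflow-providers-microsoft-azure>=12.8.0 installed):

        import os

        from airflow.providers.postgres.hooks.postgres import PostgresHook

        # Extras are passed as URL-encoded query parameters of the connection URI,
        # as described above; host, user and database names are illustrative.
        os.environ["AIRFLOW_CONN_AZURE_PG"] = (
            "postgresql://entra_user%40contoso.com@my-server.postgres.database.azure.com:5432/mydb"
            "?iam=true&azure_conn_id=azure_default"
        )

        hook = PostgresHook(postgres_conn_id="azure_pg")
        with hook.get_conn() as conn, conn.cursor() as cur:
            # The hook substitutes an Entra ID OAuth token as the password before connecting.
            cur.execute("SELECT 1")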
diff --git a/providers/postgres/docs/index.rst b/providers/postgres/docs/index.rst
index 800bf10d57e..54953989634 100644
--- a/providers/postgres/docs/index.rst
+++ b/providers/postgres/docs/index.rst
@@ -41,6 +41,7 @@
:maxdepth: 1
:caption: References
+ Configuration <configurations-ref>
Python API <_api/airflow/providers/postgres/index>
Dialects <dialects>
@@ -120,13 +121,14 @@ You can install such cross-provider dependencies when installing from PyPI. For
pip install apache-airflow-providers-postgres[amazon]
-==============================================================================================================  ===============
-Dependent package                                                                                               Extra
-==============================================================================================================  ===============
-`apache-airflow-providers-amazon <https://airflow.apache.org/docs/apache-airflow-providers-amazon>`_            ``amazon``
-`apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_    ``common.sql``
-`apache-airflow-providers-openlineage <https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_  ``openlineage``
-==============================================================================================================  ===============
+======================================================================================================================  ===================
+Dependent package                                                                                                       Extra
+======================================================================================================================  ===================
+`apache-airflow-providers-amazon <https://airflow.apache.org/docs/apache-airflow-providers-amazon>`_                    ``amazon``
+`apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_            ``common.sql``
+`apache-airflow-providers-openlineage <https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_          ``openlineage``
+`apache-airflow-providers-microsoft-azure <https://airflow.apache.org/docs/apache-airflow-providers-microsoft-azure>`_  ``microsoft.azure``
+======================================================================================================================  ===================
Downloading official packages
-----------------------------
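
Note: given the new ``microsoft.azure`` extra listed in the table above, the cross-provider dependency can be
pulled in the same way as the existing ``amazon`` example, e.g.:

    pip install apache-airflow-providers-postgres[microsoft.azure]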
diff --git a/providers/postgres/provider.yaml b/providers/postgres/provider.yaml
index 01fa0ee058b..97c0397015a 100644
--- a/providers/postgres/provider.yaml
+++ b/providers/postgres/provider.yaml
@@ -109,3 +109,17 @@ asset-uris:
dataset-uris:
- schemes: [postgres, postgresql]
handler: airflow.providers.postgres.assets.postgres.sanitize_uri
+
+config:
+ postgres:
+ description: |
+ Configuration for Postgres hooks and operators.
+ options:
+ azure_oauth_scope:
+ description: |
+      The scope to use while retrieving Oauth token for Postgres Flexible Server
+ from Azure Entra authentication.
+ version_added: 6.4.0
+ type: string
+ example: ~
+ default: "https://ossrdbms-aad.database.windows.net/.default"
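
Note: a small sketch of how this option is resolved at runtime, mirroring the lookup added to the hook further
below; the fallback shown is the documented default, and like any Airflow config option it can also be
overridden via the AIRFLOW__POSTGRES__AZURE_OAUTH_SCOPE environment variable:

    from airflow.configuration import conf

    # Falls back to the default scope for Azure Database for PostgreSQL when the
    # [postgres] section does not define azure_oauth_scope.
    scope = conf.get(
        "postgres",
        "azure_oauth_scope",
        fallback="https://ossrdbms-aad.database.windows.net/.default",
    )
    print(scope)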
diff --git a/providers/postgres/pyproject.toml b/providers/postgres/pyproject.toml
index 5c27f689aee..6a2e1aee0ab 100644
--- a/providers/postgres/pyproject.toml
+++ b/providers/postgres/pyproject.toml
@@ -70,6 +70,9 @@ dependencies = [
"amazon" = [
"apache-airflow-providers-amazon>=2.6.0",
]
+"microsoft.azure" = [
+ "apache-airflow-providers-microsoft-azure"
+]
"openlineage" = [
"apache-airflow-providers-openlineage"
]
@@ -91,6 +94,7 @@ dev = [
"apache-airflow-devel-common",
"apache-airflow-providers-amazon",
"apache-airflow-providers-common-sql",
+ "apache-airflow-providers-microsoft-azure",
"apache-airflow-providers-openlineage",
    # Additional devel dependencies (do not remove this line and add extra development dependencies)
"apache-airflow-providers-common-sql[pandas]",
diff --git a/providers/postgres/src/airflow/providers/postgres/get_provider_info.py b/providers/postgres/src/airflow/providers/postgres/get_provider_info.py
index e33bc651039..ba50c431a1f 100644
--- a/providers/postgres/src/airflow/providers/postgres/get_provider_info.py
+++ b/providers/postgres/src/airflow/providers/postgres/get_provider_info.py
@@ -65,4 +65,18 @@ def get_provider_info():
"handler":
"airflow.providers.postgres.assets.postgres.sanitize_uri",
}
],
+ "config": {
+ "postgres": {
+            "description": "Configuration for Postgres hooks and operators.\n",
+ "options": {
+ "azure_oauth_scope": {
+                    "description": "The scope to use while retrieving Oauth token for Postgres Flexible Server\nfrom Azure Entra authentication.\n",
+ "version_added": "6.4.0",
+ "type": "string",
+ "example": None,
+                    "default": "https://ossrdbms-aad.database.windows.net/.default",
+ }
+ },
+ }
+ },
}
diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
index 7a3be0ff4e3..5fecc36556a 100644
--- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
+++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
@@ -30,6 +30,7 @@ from more_itertools import chunked
from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor, execute_batch
from sqlalchemy.engine import URL
+from airflow.configuration import conf
from airflow.exceptions import (
AirflowException,
AirflowOptionalProviderFeatureException,
@@ -37,6 +38,11 @@ from airflow.exceptions import (
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.postgres.dialects.postgres import PostgresDialect
+try:
+ from airflow.sdk import Connection
+except ImportError:
+    from airflow.models.connection import Connection  # type: ignore[assignment]
+
USE_PSYCOPG3: bool
try:
import psycopg as psycopg # needed for patching in unit tests
@@ -64,11 +70,6 @@ if TYPE_CHECKING:
if USE_PSYCOPG3:
from psycopg.errors import Diagnostic
- try:
- from airflow.sdk import Connection
- except ImportError:
-        from airflow.models.connection import Connection  # type: ignore[assignment]
-
CursorType: TypeAlias = DictCursor | RealDictCursor | NamedTupleCursor
CursorRow: TypeAlias = dict[str, Any] | tuple[Any, ...]
@@ -156,7 +157,9 @@ class PostgresHook(DbApiHook):
"aws_conn_id",
"sqlalchemy_scheme",
"sqlalchemy_query",
+ "azure_conn_id",
}
+    default_azure_oauth_scope = "https://ossrdbms-aad.database.windows.net/.default"
def __init__(
        self, *args, options: str | None = None, enable_log_db_messages: bool = False, **kwargs
@@ -177,6 +180,8 @@ class PostgresHook(DbApiHook):
query = conn.extra_dejson.get("sqlalchemy_query", {})
if not isinstance(query, dict):
            raise AirflowException("The parameter 'sqlalchemy_query' must be of type dict!")
+ if conn.extra_dejson.get("iam", False):
+ conn.login, conn.password, conn.port = self.get_iam_token(conn)
return URL.create(
drivername="postgresql+psycopg" if USE_PSYCOPG3 else "postgresql",
username=self.__cast_nullable(conn.login, str),
@@ -441,8 +446,14 @@ class PostgresHook(DbApiHook):
return PostgresHook._serialize_cell_ppg2(cell, conn)
def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
+ """Get the IAM token from different identity providers."""
+ if conn.extra_dejson.get("azure_conn_id"):
+ return self.get_azure_iam_token(conn)
+ return self.get_aws_iam_token(conn)
+
+ def get_aws_iam_token(self, conn: Connection) -> tuple[str, str, int]:
"""
- Get the IAM token.
+ Get the AWS IAM token.
This uses AWSHook to retrieve a temporary password to connect to
        Postgres or Redshift. Port is required. If none is provided, the default
@@ -500,6 +511,36 @@ class PostgresHook(DbApiHook):
        token = rds_client.generate_db_auth_token(conn.host, port, conn.login)
return cast("str", login), cast("str", token), port
+ def get_azure_iam_token(self, conn: Connection) -> tuple[str, str, int]:
+ """
+ Get the Azure IAM token.
+
+        This uses AzureBaseHook to retrieve an OAUTH token to connect to Postgres.
+        Scope for the OAuth token can be set in the config option ``azure_oauth_scope`` under the section ``[postgres]``.
+ """
+ if TYPE_CHECKING:
+            from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook
+
+ azure_conn_id = conn.extra_dejson.get("azure_conn_id", "azure_default")
+ try:
+ azure_conn = Connection.get(azure_conn_id)
+ except AttributeError:
+            azure_conn = Connection.get_connection_from_secrets(azure_conn_id)  # type: ignore[attr-defined]
+ azure_base_hook: AzureBaseHook = azure_conn.get_hook()
+        scope = conf.get("postgres", "azure_oauth_scope", fallback=self.default_azure_oauth_scope)
+ try:
+ token = azure_base_hook.get_token(scope).token
+ except AttributeError as e:
+ if e.name == "get_token" and e.obj == azure_base_hook:
+ raise AttributeError(
+ "'AzureBaseHook' object has no attribute 'get_token'. "
+                    "Please upgrade apache-airflow-providers-microsoft-azure>=12.8.0",
+ name=e.name,
+ obj=e.obj,
+ ) from e
+ raise
+        return cast("str", conn.login or azure_conn.login), token, conn.port or 5432
+
    def get_table_primary_key(self, table: str, schema: str | None = "public") -> list[str] | None:
"""
Get the table's primary key.
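
Note: for readers new to Entra ID authentication against Azure Database for PostgreSQL, the token flow that
``get_azure_iam_token`` delegates to ``AzureBaseHook.get_token`` for looks roughly like this outside Airflow
(a sketch using azure-identity directly rather than the hook; server, user and database names are illustrative):

    import psycopg2
    from azure.identity import DefaultAzureCredential

    SCOPE = "https://ossrdbms-aad.database.windows.net/.default"

    # Acquire a short-lived OAuth access token for the Azure Database for PostgreSQL resource...
    credential = DefaultAzureCredential()
    token = credential.get_token(SCOPE).token

    # ...and pass it as the password, with the Entra principal as the database user.
    conn = psycopg2.connect(
        host="my-server.postgres.database.azure.com",
        port=5432,
        user="[email protected]",
        password=token,
        dbname="mydb",
        sslmode="require",
    )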
diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
index 0607edc4405..47fffe0b6bb 100644
--- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
+++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
@@ -444,6 +444,57 @@ class TestPostgresHookConn:
port=(port or 5439),
)
+ def test_get_conn_azure_iam(self, mocker, mock_connect):
+ mock_azure_conn_id = "azure_conn1"
+ mock_db_token = "azure_token1"
+ mock_conn_extra = {"iam": True, "azure_conn_id": mock_azure_conn_id}
+ self.connection.extra = json.dumps(mock_conn_extra)
+
+        mock_connection_class = mocker.patch("airflow.providers.postgres.hooks.postgres.Connection")
+        mock_azure_base_hook = mock_connection_class.get.return_value.get_hook.return_value
+ mock_azure_base_hook.get_token.return_value.token = mock_db_token
+
+ self.db_hook.get_conn()
+
+ # Check AzureBaseHook initialization and get_token call args
+ mock_connection_class.get.assert_called_once_with(mock_azure_conn_id)
+        mock_azure_base_hook.get_token.assert_called_once_with(PostgresHook.default_azure_oauth_scope)
+
+ # Check expected psycopg2 connection call args
+ mock_connect.assert_called_once_with(
+ user=self.connection.login,
+ password=mock_db_token,
+ host=self.connection.host,
+ dbname=self.connection.schema,
+ port=(self.connection.port or 5432),
+ )
+
+ assert mock_db_token in self.db_hook.sqlalchemy_url
+
+ def test_get_azure_iam_token_expect_failure_on_get_token(self, mocker):
+        """Test get_azure_iam_token method gets token from provided connection id"""
+
+ class MockAzureBaseHookWithoutGetToken:
+ def __init__(self):
+ pass
+
+ azure_conn_id = "azure_test_conn"
+        mock_connection_class = mocker.patch("airflow.providers.postgres.hooks.postgres.Connection")
+        mock_connection_class.get.return_value.get_hook.return_value = MockAzureBaseHookWithoutGetToken()
+
+        self.connection.extra = json.dumps({"iam": True, "azure_conn_id": azure_conn_id})
+ with pytest.raises(
+ AttributeError,
+ match=(
+ "'AzureBaseHook' object has no attribute 'get_token'. "
+ "Please upgrade apache-airflow-providers-microsoft-azure>="
+ ),
+ ):
+ self.db_hook.get_azure_iam_token(self.connection)
+
+ # Check AzureBaseHook initialization
+ mock_connection_class.get.assert_called_once_with(azure_conn_id)
+
def test_get_uri_from_connection_without_database_override(self, mocker):
        expected: str = f"postgresql{'+psycopg' if USE_PSYCOPG3 else ''}://login:password@host:1/database"
self.db_hook.get_connection = mocker.MagicMock(