This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0faffe40b4e Azure IAM/Entra ID support for PostgresHook (#55729)
0faffe40b4e is described below

commit 0faffe40b4e561aad3cfcfca4632f1bcf3ad2bca
Author: karunpoudel <[email protected]>
AuthorDate: Sat Oct 11 04:59:37 2025 -0400

    Azure IAM/Entra ID support for PostgresHook (#55729)
    
    * init
    
    fx
    
    fx
    
    pre-commit
    
    progress
    
    fx
    
    fx
    
    fx
    
    * fx
    
    * add config
    
    * conf doc
    
    * fx
    
    * fix1
    
    * update version
    
    * fix doc
    
    * prek fix
    
    * updates
    
    * fix
    
    * Update pyproject.toml
    
    * Update pyproject.toml
    
    * fix
    
    * try
    
    * pass
    
    * doc and test
    
    ---------
    
    Co-authored-by: Karun Poudel <[email protected]>
    Co-authored-by: Karun Poudel 
<[email protected]>
---
 dev/breeze/tests/test_selective_checks.py          |  4 +-
 providers/postgres/docs/configurations-ref.rst     | 19 ++++++++
 providers/postgres/docs/connections/postgres.rst   | 15 +++++-
 providers/postgres/docs/index.rst                  | 16 ++++---
 providers/postgres/provider.yaml                   | 14 ++++++
 providers/postgres/pyproject.toml                  |  4 ++
 .../providers/postgres/get_provider_info.py        | 14 ++++++
 .../airflow/providers/postgres/hooks/postgres.py   | 53 +++++++++++++++++++---
 .../tests/unit/postgres/hooks/test_postgres.py     | 51 +++++++++++++++++++++
 9 files changed, 174 insertions(+), 16 deletions(-)

diff --git a/dev/breeze/tests/test_selective_checks.py 
b/dev/breeze/tests/test_selective_checks.py
index 5c00a026921..a972f9d11ec 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -645,7 +645,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, 
str], stderr: str):
                 ),
                 {
                     "selected-providers-list-as-string": "amazon common.sql 
google "
-                    "openlineage pgvector postgres",
+                    "microsoft.azure openlineage pgvector postgres",
                     "all-python-versions": 
f"['{DEFAULT_PYTHON_MAJOR_MINOR_VERSION}']",
                     "all-python-versions-list-as-string": 
DEFAULT_PYTHON_MAJOR_MINOR_VERSION,
                     "python-versions": 
f"['{DEFAULT_PYTHON_MAJOR_MINOR_VERSION}']",
@@ -667,7 +667,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, 
str], stderr: str):
                             {
                                 "description": "amazon...google",
                                 "test_types": "Providers[amazon] "
-                                
"Providers[common.sql,openlineage,pgvector,postgres] "
+                                
"Providers[common.sql,microsoft.azure,openlineage,pgvector,postgres] "
                                 "Providers[google]",
                             }
                         ]
diff --git a/providers/postgres/docs/configurations-ref.rst 
b/providers/postgres/docs/configurations-ref.rst
new file mode 100644
index 00000000000..a52b21b2e56
--- /dev/null
+++ b/providers/postgres/docs/configurations-ref.rst
@@ -0,0 +1,19 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. include:: 
/../../../devel-common/src/sphinx_exts/includes/providers-configurations-ref.rst
+.. include:: 
/../../../devel-common/src/sphinx_exts/includes/sections-and-options.rst
diff --git a/providers/postgres/docs/connections/postgres.rst 
b/providers/postgres/docs/connections/postgres.rst
index 539620ad08c..3018769afbc 100644
--- a/providers/postgres/docs/connections/postgres.rst
+++ b/providers/postgres/docs/connections/postgres.rst
@@ -96,7 +96,9 @@ Extra (optional)
     * ``iam`` - If set to ``True`` then use AWS IAM database authentication for
       `Amazon RDS 
<https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html>`__,
       `Amazon Aurora 
<https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.html>`__
-      or `Amazon Redshift 
<https://docs.aws.amazon.com/redshift/latest/mgmt/generating-user-credentials.html>`__.
+      `Amazon Redshift 
<https://docs.aws.amazon.com/redshift/latest/mgmt/generating-user-credentials.html>`__
+      or use Microsoft Entra Authentication for
+      `Azure Postgres Flexible Server 
<https://learn.microsoft.com/en-us/azure/postgresql/flexible-server/security-entra-concepts>`__.
     * ``aws_conn_id`` - AWS Connection ID which use for authentication via AWS 
IAM,
       if not specified then **aws_default** is used.
     * ``redshift`` - Used when AWS IAM database authentication enabled.
@@ -104,6 +106,8 @@ Extra (optional)
     * ``cluster-identifier`` - The unique identifier of the Amazon Redshift 
Cluster that contains the database
       for which you are requesting credentials. This parameter is case 
sensitive.
       If not specified then hostname from **Connection Host** is used.
+    * ``azure_conn_id`` - Azure Connection ID to be used for authentication 
via Azure Entra ID. Azure Oauth token
+      is retrieved from the azure connection which is used as password for 
PostgreSQL connection. Scope for the Azure OAuth token can be set in the config 
option ``azure_oauth_scope`` under the section ``[postgres]``. Requires 
`apache-airflow-providers-microsoft-azure>=12.8.0`.
 
     Example "extras" field (Amazon RDS PostgreSQL or Amazon Aurora PostgreSQL):
 
@@ -125,6 +129,15 @@ Extra (optional)
           "cluster-identifier": "awesome-redshift-identifier"
        }
 
+    Example "extras" field (to use Azure Entra Authentication for Postgres 
Flexible Server):
+
+    .. code-block:: json
+
+       {
+          "iam": true,
+          "azure_conn_id": "azure_default_conn"
+       }
+
     When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}` 
variable) you should specify it
     following the standard syntax of DB connections, where extras are passed 
as parameters
     of the URI (note that all components of the URI should be URL-encoded).
diff --git a/providers/postgres/docs/index.rst 
b/providers/postgres/docs/index.rst
index 800bf10d57e..54953989634 100644
--- a/providers/postgres/docs/index.rst
+++ b/providers/postgres/docs/index.rst
@@ -41,6 +41,7 @@
     :maxdepth: 1
     :caption: References
 
+    Configuration <configurations-ref>
     Python API <_api/airflow/providers/postgres/index>
     Dialects <dialects>
 
@@ -120,13 +121,14 @@ You can install such cross-provider dependencies when 
installing from PyPI. For
     pip install apache-airflow-providers-postgres[amazon]
 
 
-==============================================================================================================
  ===============
-Dependent package                                                              
                                 Extra
-==============================================================================================================
  ===============
-`apache-airflow-providers-amazon 
<https://airflow.apache.org/docs/apache-airflow-providers-amazon>`_            
``amazon``
-`apache-airflow-providers-common-sql 
<https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_    
``common.sql``
-`apache-airflow-providers-openlineage 
<https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_  
``openlineage``
-==============================================================================================================
  ===============
+======================================================================================================================
  ===============
+Dependent package                                                              
                                         Extra
+======================================================================================================================
  ===============
+`apache-airflow-providers-amazon 
<https://airflow.apache.org/docs/apache-airflow-providers-amazon>`_             
       ``amazon``
+`apache-airflow-providers-common-sql 
<https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_         
   ``common.sql``
+`apache-airflow-providers-openlineage 
<https://airflow.apache.org/docs/apache-airflow-providers-openlineage>`_        
  ``openlineage``
+`apache-airflow-providers-microsoft-azure 
<https://airflow.apache.org/docs/apache-airflow-providers-microsoft-azure>`_  
``microsoft.azure``
+======================================================================================================================
  ===============
 
 Downloading official packages
 -----------------------------
diff --git a/providers/postgres/provider.yaml b/providers/postgres/provider.yaml
index 01fa0ee058b..97c0397015a 100644
--- a/providers/postgres/provider.yaml
+++ b/providers/postgres/provider.yaml
@@ -109,3 +109,17 @@ asset-uris:
 dataset-uris:
   - schemes: [postgres, postgresql]
     handler: airflow.providers.postgres.assets.postgres.sanitize_uri
+
+config:
+  postgres:
+    description: |
+      Configuration for Postgres hooks and operators.
+    options:
+      azure_oauth_scope:
+        description: |
+          The scope to use while retrieving Oauth token for Postgres Flexible 
Server
+          from Azure Entra authentication.
+        version_added: 6.4.0
+        type: string
+        example: ~
+        default: "https://ossrdbms-aad.database.windows.net/.default"
diff --git a/providers/postgres/pyproject.toml 
b/providers/postgres/pyproject.toml
index 5c27f689aee..6a2e1aee0ab 100644
--- a/providers/postgres/pyproject.toml
+++ b/providers/postgres/pyproject.toml
@@ -70,6 +70,9 @@ dependencies = [
 "amazon" = [
     "apache-airflow-providers-amazon>=2.6.0",
 ]
+"microsoft.azure" = [
+    "apache-airflow-providers-microsoft-azure"
+]
 "openlineage" = [
     "apache-airflow-providers-openlineage"
 ]
@@ -91,6 +94,7 @@ dev = [
     "apache-airflow-devel-common",
     "apache-airflow-providers-amazon",
     "apache-airflow-providers-common-sql",
+    "apache-airflow-providers-microsoft-azure",
     "apache-airflow-providers-openlineage",
     # Additional devel dependencies (do not remove this line and add extra 
development dependencies)
     "apache-airflow-providers-common-sql[pandas]",
diff --git 
a/providers/postgres/src/airflow/providers/postgres/get_provider_info.py 
b/providers/postgres/src/airflow/providers/postgres/get_provider_info.py
index e33bc651039..ba50c431a1f 100644
--- a/providers/postgres/src/airflow/providers/postgres/get_provider_info.py
+++ b/providers/postgres/src/airflow/providers/postgres/get_provider_info.py
@@ -65,4 +65,18 @@ def get_provider_info():
                 "handler": 
"airflow.providers.postgres.assets.postgres.sanitize_uri",
             }
         ],
+        "config": {
+            "postgres": {
+                "description": "Configuration for Postgres hooks and 
operators.\n",
+                "options": {
+                    "azure_oauth_scope": {
+                        "description": "The scope to use while retrieving 
Oauth token for Postgres Flexible Server\nfrom Azure Entra authentication.\n",
+                        "version_added": "6.4.0",
+                        "type": "string",
+                        "example": None,
+                        "default": 
"https://ossrdbms-aad.database.windows.net/.default",
+                    }
+                },
+            }
+        },
     }
diff --git 
a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py 
b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
index 7a3be0ff4e3..5fecc36556a 100644
--- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
+++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
@@ -30,6 +30,7 @@ from more_itertools import chunked
 from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor, 
execute_batch
 from sqlalchemy.engine import URL
 
+from airflow.configuration import conf
 from airflow.exceptions import (
     AirflowException,
     AirflowOptionalProviderFeatureException,
@@ -37,6 +38,11 @@ from airflow.exceptions import (
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.postgres.dialects.postgres import PostgresDialect
 
+try:
+    from airflow.sdk import Connection
+except ImportError:
+    from airflow.models.connection import Connection  # type: 
ignore[assignment]
+
 USE_PSYCOPG3: bool
 try:
     import psycopg as psycopg  # needed for patching in unit tests
@@ -64,11 +70,6 @@ if TYPE_CHECKING:
     if USE_PSYCOPG3:
         from psycopg.errors import Diagnostic
 
-    try:
-        from airflow.sdk import Connection
-    except ImportError:
-        from airflow.models.connection import Connection  # type: 
ignore[assignment]
-
 CursorType: TypeAlias = DictCursor | RealDictCursor | NamedTupleCursor
 CursorRow: TypeAlias = dict[str, Any] | tuple[Any, ...]
 
@@ -156,7 +157,9 @@ class PostgresHook(DbApiHook):
         "aws_conn_id",
         "sqlalchemy_scheme",
         "sqlalchemy_query",
+        "azure_conn_id",
     }
+    default_azure_oauth_scope = 
"https://ossrdbms-aad.database.windows.net/.default"
 
     def __init__(
         self, *args, options: str | None = None, enable_log_db_messages: bool 
= False, **kwargs
@@ -177,6 +180,8 @@ class PostgresHook(DbApiHook):
         query = conn.extra_dejson.get("sqlalchemy_query", {})
         if not isinstance(query, dict):
             raise AirflowException("The parameter 'sqlalchemy_query' must be 
of type dict!")
+        if conn.extra_dejson.get("iam", False):
+            conn.login, conn.password, conn.port = self.get_iam_token(conn)
         return URL.create(
             drivername="postgresql+psycopg" if USE_PSYCOPG3 else "postgresql",
             username=self.__cast_nullable(conn.login, str),
@@ -441,8 +446,14 @@ class PostgresHook(DbApiHook):
         return PostgresHook._serialize_cell_ppg2(cell, conn)
 
     def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
+        """Get the IAM token from different identity providers."""
+        if conn.extra_dejson.get("azure_conn_id"):
+            return self.get_azure_iam_token(conn)
+        return self.get_aws_iam_token(conn)
+
+    def get_aws_iam_token(self, conn: Connection) -> tuple[str, str, int]:
         """
-        Get the IAM token.
+        Get the AWS IAM token.
 
         This uses AWSHook to retrieve a temporary password to connect to
         Postgres or Redshift. Port is required. If none is provided, the 
default
@@ -500,6 +511,36 @@ class PostgresHook(DbApiHook):
             token = rds_client.generate_db_auth_token(conn.host, port, 
conn.login)
         return cast("str", login), cast("str", token), port
 
+    def get_azure_iam_token(self, conn: Connection) -> tuple[str, str, int]:
+        """
+        Get the Azure IAM token.
+
+        This uses AzureBaseHook to retrieve an OAUTH token to connect to 
Postgres.
+        Scope for the OAuth token can be set in the config option 
``azure_oauth_scope`` under the section ``[postgres]``.
+        """
+        if TYPE_CHECKING:
+            from airflow.providers.microsoft.azure.hooks.base_azure import 
AzureBaseHook
+
+        azure_conn_id = conn.extra_dejson.get("azure_conn_id", "azure_default")
+        try:
+            azure_conn = Connection.get(azure_conn_id)
+        except AttributeError:
+            azure_conn = Connection.get_connection_from_secrets(azure_conn_id) 
 # type: ignore[attr-defined]
+        azure_base_hook: AzureBaseHook = azure_conn.get_hook()
+        scope = conf.get("postgres", "azure_oauth_scope", 
fallback=self.default_azure_oauth_scope)
+        try:
+            token = azure_base_hook.get_token(scope).token
+        except AttributeError as e:
+            if e.name == "get_token" and e.obj == azure_base_hook:
+                raise AttributeError(
+                    "'AzureBaseHook' object has no attribute 'get_token'. "
+                    "Please upgrade 
apache-airflow-providers-microsoft-azure>=12.8.0",
+                    name=e.name,
+                    obj=e.obj,
+                ) from e
+            raise
+        return cast("str", conn.login or azure_conn.login), token, conn.port 
or 5432
+
     def get_table_primary_key(self, table: str, schema: str | None = "public") 
-> list[str] | None:
         """
         Get the table's primary key.
diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py 
b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
index 0607edc4405..47fffe0b6bb 100644
--- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
+++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
@@ -444,6 +444,57 @@ class TestPostgresHookConn:
             port=(port or 5439),
         )
 
+    def test_get_conn_azure_iam(self, mocker, mock_connect):
+        mock_azure_conn_id = "azure_conn1"
+        mock_db_token = "azure_token1"
+        mock_conn_extra = {"iam": True, "azure_conn_id": mock_azure_conn_id}
+        self.connection.extra = json.dumps(mock_conn_extra)
+
+        mock_connection_class = 
mocker.patch("airflow.providers.postgres.hooks.postgres.Connection")
+        mock_azure_base_hook = 
mock_connection_class.get.return_value.get_hook.return_value
+        mock_azure_base_hook.get_token.return_value.token = mock_db_token
+
+        self.db_hook.get_conn()
+
+        # Check AzureBaseHook initialization and get_token call args
+        mock_connection_class.get.assert_called_once_with(mock_azure_conn_id)
+        
mock_azure_base_hook.get_token.assert_called_once_with(PostgresHook.default_azure_oauth_scope)
+
+        # Check expected psycopg2 connection call args
+        mock_connect.assert_called_once_with(
+            user=self.connection.login,
+            password=mock_db_token,
+            host=self.connection.host,
+            dbname=self.connection.schema,
+            port=(self.connection.port or 5432),
+        )
+
+        assert mock_db_token in self.db_hook.sqlalchemy_url
+
+    def test_get_azure_iam_token_expect_failure_on_get_token(self, mocker):
+        """Test get_azure_iam_token method gets token from provided connection 
id"""
+
+        class MockAzureBaseHookWithoutGetToken:
+            def __init__(self):
+                pass
+
+        azure_conn_id = "azure_test_conn"
+        mock_connection_class = 
mocker.patch("airflow.providers.postgres.hooks.postgres.Connection")
+        mock_connection_class.get.return_value.get_hook.return_value = 
MockAzureBaseHookWithoutGetToken()
+
+        self.connection.extra = json.dumps({"iam": True, "azure_conn_id": 
azure_conn_id})
+        with pytest.raises(
+            AttributeError,
+            match=(
+                "'AzureBaseHook' object has no attribute 'get_token'. "
+                "Please upgrade apache-airflow-providers-microsoft-azure>="
+            ),
+        ):
+            self.db_hook.get_azure_iam_token(self.connection)
+
+        # Check AzureBaseHook initialization
+        mock_connection_class.get.assert_called_once_with(azure_conn_id)
+
     def test_get_uri_from_connection_without_database_override(self, mocker):
         expected: str = f"postgresql{'+psycopg' if USE_PSYCOPG3 else 
''}://login:password@host:1/database"
         self.db_hook.get_connection = mocker.MagicMock(

Reply via email to