This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c142ec5e64 Support optional scope in OAuth token request (#58871)
1c142ec5e64 is described below

commit 1c142ec5e641d5f31b437feddab5b11376f5de0a
Author: SameerMesiah97 <[email protected]>
AuthorDate: Mon Dec 1 22:43:24 2025 +0000

    Support optional scope in OAuth token request (#58871)
    
    Omit scope field when unset instead of defaulting to None, and add tests for scoped and unscoped OAuth requests
    
    Fix SnowflakeHook OAuth tests to correctly handle client_credentials flow
    
    Docs: Add ``scope`` parameter description for OAuth in SnowflakeHook
    
    Fixed tests to ensure the client_credentials code path is tested instead of the refresh_token path
    
    Fixed formatting issues and removed extraneous comments
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 providers/snowflake/docs/connections/snowflake.rst |   1 +
 .../airflow/providers/snowflake/hooks/snowflake.py |   8 ++
 .../tests/unit/snowflake/hooks/test_snowflake.py   | 108 +++++++++++++++++++++
 3 files changed, 117 insertions(+)

diff --git a/providers/snowflake/docs/connections/snowflake.rst b/providers/snowflake/docs/connections/snowflake.rst
index 3a523d97b4e..900ed04a1ed 100644
--- a/providers/snowflake/docs/connections/snowflake.rst
+++ b/providers/snowflake/docs/connections/snowflake.rst
@@ -60,6 +60,7 @@ Extra (optional)
     * ``authenticator``: To connect using OAuth set this parameter ``oauth``.
     * ``token_endpoint``: Specify token endpoint for external OAuth provider.
     * ``grant_type``: Specify grant type for OAuth authentication. Currently supported: ``refresh_token`` (default), ``client_credentials``.
+    * ``scope``: Specify the OAuth scope to include in the access token request. Applies to any OAuth grant type; if unset, no scope is sent.
     * ``refresh_token``: Specify refresh_token for OAuth connection.
     * ``azure_conn_id``: Azure Connection ID to be used for retrieving the OAuth token using Azure Entra authentication. Login and Password fields aren't required when using this method. Scope for the Azure OAuth token can be set in the config option ``azure_oauth_scope`` under the section ``[snowflake]``. Requires `apache-airflow-providers-microsoft-azure>=12.8.0`.
     * ``private_key_file``: Specify the path to the private key file.
diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index a1aa025f726..2e0c24e419a 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -142,6 +142,7 @@ class SnowflakeHook(DbApiHook):
                         "grant_type": "refresh_token client_credentials",
                         "token_endpoint": "token endpoint",
                         "refresh_token": "refresh token",
+                        "scope": "scope",
                     },
                     indent=1,
                 ),
@@ -223,6 +224,11 @@ class SnowflakeHook(DbApiHook):
             "redirect_uri": conn_config.get("redirect_uri", "https://localhost.com"),
         }
 
+        scope = conn_config.get("scope")
+
+        if scope:
+            data["scope"] = scope
+
         if grant_type == "refresh_token":
             data |= {
                 "refresh_token": conn_config["refresh_token"],
@@ -388,8 +394,10 @@ class SnowflakeHook(DbApiHook):
                 conn_config["token"] = self.get_azure_oauth_token(extra_dict["azure_conn_id"])
             else:
                 token_endpoint = self._get_field(extra_dict, "token_endpoint") or ""
+                conn_config["scope"] = self._get_field(extra_dict, "scope")
                 conn_config["client_id"] = conn.login
                 conn_config["client_secret"] = conn.password
+
                 conn_config["token"] = self.get_oauth_token(
                     conn_config=conn_config,
                     token_endpoint=token_endpoint,
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index 163b8753801..4485c87c7a2 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -704,6 +704,44 @@ class TestPytestSnowflakeHook:
         assert "region" in conn_params
         assert "account" in conn_params
 
+    @mock.patch("requests.post")
+    def test_get_conn_params_include_scope(self, mock_requests_post):
+        """
+        Verify that `_get_conn_params` includes the `scope` field when it is present
+        in the connection extras.
+        """
+        mock_requests_post.return_value = Mock(
+            status_code=200,
+            json=lambda: {
+                "access_token": "dummy",
+                "expires_in": 600,
+                "token_type": "Bearer",
+                "username": "test_user",
+            },
+        )
+
+        connection_kwargs = {
+            **BASE_CONNECTION_KWARGS,
+            "login": "test_client_id",
+            "password": "test_client_secret",
+            "extra": {
+                "account": "airflow",
+                "authenticator": "oauth",
+                "grant_type": "client_credentials",
+                "scope": "default",
+            },
+        }
+
+        with mock.patch.dict(
+            "os.environ",
+            {"AIRFLOW_CONN_TEST_CONN": Connection(**connection_kwargs).get_uri()},
+        ):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            params = hook._get_conn_params
+        mock_requests_post.assert_called_once()
+        assert "scope" in params
+        assert params["scope"] == "default"
+
     def test_should_add_partner_info(self):
         with mock.patch.dict(
             "os.environ",
@@ -1133,3 +1171,73 @@ class TestPytestSnowflakeHook:
 
         # Check AzureBaseHook initialization
         mock_connection_class.get.assert_called_once_with(azure_conn_id)
+
+    @mock.patch("requests.post")
+    def test_get_oauth_token_with_scope(self, mock_requests_post):
+        """
+        Verify that `get_oauth_token` returns an access token and includes the
+        provided scope in the outgoing OAuth request payload.
+        """
+
+        mock_requests_post.return_value = Mock(
+            status_code=200,
+            json=lambda: {"access_token": "dummy_token"},
+        )
+
+        connection_kwargs = {
+            **BASE_CONNECTION_KWARGS,
+            "login": "client_id",
+            "password": "client_secret",
+            "extra": {
+                "account": "airflow",
+                "authenticator": "oauth",
+                "grant_type": "client_credentials",
+                "scope": "default",
+            },
+        }
+
+        with mock.patch.dict(
+            "os.environ",
+            {"AIRFLOW_CONN_TEST_CONN": Connection(**connection_kwargs).get_uri()},
+        ):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            token = hook.get_oauth_token(grant_type="client_credentials")
+
+        assert token == "dummy_token"
+
+        called_data = mock_requests_post.call_args.kwargs["data"]
+
+        assert called_data["scope"] == "default"
+        assert called_data["grant_type"] == "client_credentials"
+
+    @mock.patch("requests.post")
+    def test_get_oauth_token_without_scope(self, mock_requests_post):
+        """
+        Verify that `get_oauth_token` returns an access token and omits `scope`
+        from the request payload when no scope is defined in the connection extras.
+        """
+        mock_requests_post.return_value = Mock(
+            status_code=200,
+            json=lambda: {"access_token": "dummy_token"},
+        )
+
+        connection_kwargs = {
+            **BASE_CONNECTION_KWARGS,
+            "login": "client_id",
+            "password": "client_secret",
+            "extra": {"account": "airflow", "authenticator": "oauth", "grant_type": "client_credentials"},
+        }
+
+        with mock.patch.dict(
+            "os.environ",
+            {"AIRFLOW_CONN_TEST_CONN": Connection(**connection_kwargs).get_uri()},
+        ):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            token = hook.get_oauth_token(grant_type="client_credentials")
+
+        assert token == "dummy_token"
+
+        called_data = mock_requests_post.call_args.kwargs["data"]
+
+        assert "scope" not in called_data
+        assert called_data["grant_type"] == "client_credentials"

Reply via email to