This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6449262f4d9 Extend SnowflakeHook OAuth implementation to support 
external IDPs and client_credentials grant (#51620)
6449262f4d9 is described below

commit 6449262f4d9de2f3a43967550a9ee08421aa6268
Author: Philippe Gagnon <[email protected]>
AuthorDate: Wed Jun 11 16:27:39 2025 -0400

    Extend SnowflakeHook OAuth implementation to support external IDPs and 
client_credentials grant (#51620)
    
    This PR extends `SnowflakeHook`'s OAuth implementation, allowing users to 
specify a `token_endpoint` connection extra to support external identity providers 
(IdPs) other than Snowflake's built-in OAuth server.
    
    It also introduces the `grant_type` parameter in order to support grant 
types other than `refresh_token`. This initial implementation adds support 
for the `client_credentials` grant type, in which the client ID and secret supplied 
through the connection's login/password fields are exchanged for an access token.
    
    Backwards compatibility with the existing interface is preserved.
---
 providers/snowflake/docs/connections/snowflake.rst |   2 +
 .../airflow/providers/snowflake/hooks/snowflake.py |  39 ++++--
 .../providers/snowflake/hooks/snowflake_sql_api.py |  11 +-
 .../tests/unit/snowflake/hooks/test_snowflake.py   | 139 ++++++++++++++++++++-
 4 files changed, 180 insertions(+), 11 deletions(-)

diff --git a/providers/snowflake/docs/connections/snowflake.rst 
b/providers/snowflake/docs/connections/snowflake.rst
index 9c8c6a4a875..e15280ceb77 100644
--- a/providers/snowflake/docs/connections/snowflake.rst
+++ b/providers/snowflake/docs/connections/snowflake.rst
@@ -58,6 +58,8 @@ Extra (optional)
     * ``warehouse``: Snowflake warehouse name.
     * ``role``: Snowflake role.
     * ``authenticator``: To connect using OAuth set this parameter ``oauth``.
+    * ``token_endpoint``: Specify token endpoint for external OAuth provider.
+    * ``grant_type``: Specify grant type for OAuth authentication. Currently 
supported: ``refresh_token`` (default), ``client_credentials``.
     * ``refresh_token``: Specify refresh_token for OAuth connection.
     * ``private_key_file``: Specify the path to the private key file.
     * ``private_key_content``: Specify the content of the private key file in 
base64 encoded format. You can use the following Python code to encode the 
private key:
diff --git 
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py 
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index 5535a5e3549..88cb5f331e2 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -136,6 +136,9 @@ class SnowflakeHook(DbApiHook):
                         "session_parameters": "session parameters",
                         "client_request_mfa_token": "client request mfa token",
                         "client_store_temporary_credential": "client store 
temporary credential (externalbrowser mode)",
+                        "grant_type": "refresh_token client_credentials",
+                        "token_endpoint": "token endpoint",
+                        "refresh_token": "refresh token",
                     },
                     indent=1,
                 ),
@@ -200,18 +203,32 @@ class SnowflakeHook(DbApiHook):
 
         return account_identifier
 
-    def get_oauth_token(self, conn_config: dict | None = None) -> str:
+    def get_oauth_token(
+        self,
+        conn_config: dict | None = None,
+        token_endpoint: str | None = None,
+        grant_type: str = "refresh_token",
+    ) -> str:
         """Generate temporary OAuth access token using refresh token in 
connection details."""
         if conn_config is None:
             conn_config = self._get_conn_params
 
-        url = 
f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
+        url = token_endpoint or 
f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
 
         data = {
-            "grant_type": "refresh_token",
-            "refresh_token": conn_config["refresh_token"],
+            "grant_type": grant_type,
             "redirect_uri": conn_config.get("redirect_uri", 
"https://localhost.com";),
         }
+
+        if grant_type == "refresh_token":
+            data |= {
+                "refresh_token": conn_config["refresh_token"],
+            }
+        elif grant_type == "client_credentials":
+            pass  # no setup necessary for client credentials grant.
+        else:
+            raise ValueError(f"Unknown grant_type: {grant_type}")
+
         response = requests.post(
             url,
             data=data,
@@ -226,7 +243,8 @@ class SnowflakeHook(DbApiHook):
         except requests.exceptions.HTTPError as e:  # pragma: no cover
             msg = f"Response: {e.response.content.decode()} Status Code: 
{e.response.status_code}"
             raise AirflowException(msg)
-        return response.json()["access_token"]
+        token = response.json()["access_token"]
+        return token
 
     @cached_property
     def _get_conn_params(self) -> dict[str, str | None]:
@@ -329,14 +347,21 @@ class SnowflakeHook(DbApiHook):
         if refresh_token:
             conn_config["refresh_token"] = refresh_token
             conn_config["authenticator"] = "oauth"
+
+        if conn_config.get("authenticator") == "oauth":
+            token_endpoint = self._get_field(extra_dict, "token_endpoint") or 
""
             conn_config["client_id"] = conn.login
             conn_config["client_secret"] = conn.password
+            conn_config["token"] = self.get_oauth_token(
+                conn_config=conn_config,
+                token_endpoint=token_endpoint,
+                grant_type=extra_dict.get("grant_type", "refresh_token"),
+            )
+
             conn_config.pop("login", None)
             conn_config.pop("user", None)
             conn_config.pop("password", None)
 
-            conn_config["token"] = 
self.get_oauth_token(conn_config=conn_config)
-
         # configure custom target hostname and port, if specified
         snowflake_host = extra_dict.get("host")
         snowflake_port = extra_dict.get("port")
diff --git 
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
 
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index 0a2c9fd4247..7e4ef8dfe02 100644
--- 
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++ 
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -223,14 +223,21 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         }
         return headers
 
-    def get_oauth_token(self, conn_config: dict[str, Any] | None = None) -> 
str:
+    def get_oauth_token(
+        self,
+        conn_config: dict[str, Any] | None = None,
+        token_endpoint: str | None = None,
+        grant_type: str = "refresh_token",
+    ) -> str:
         """Generate temporary OAuth access token using refresh token in 
connection details."""
         warnings.warn(
             "This method is deprecated. Please use `get_oauth_token` method 
from `SnowflakeHook` instead. ",
             AirflowProviderDeprecationWarning,
             stacklevel=2,
         )
-        return super().get_oauth_token(conn_config=conn_config)
+        return super().get_oauth_token(
+            conn_config=conn_config, token_endpoint=token_endpoint, 
grant_type=grant_type
+        )
 
     def get_request_url_header_params(self, query_id: str) -> tuple[dict[str, 
Any], dict[str, Any], str]:
         """
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py 
b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index ec384ddc5c8..e28d0bca28e 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -53,14 +53,13 @@ BASE_CONNECTION_KWARGS: dict = {
     },
 }
 
-CONN_PARAMS_OAUTH = {
+CONN_PARAMS_OAUTH_BASE = {
     "account": "airflow",
     "application": "AIRFLOW",
     "authenticator": "oauth",
     "database": "db",
     "client_id": "test_client_id",
     "client_secret": "test_client_pw",
-    "refresh_token": "secrettoken",
     "region": "af_region",
     "role": "af_role",
     "schema": "public",
@@ -68,6 +67,8 @@ CONN_PARAMS_OAUTH = {
     "warehouse": "af_wh",
 }
 
+CONN_PARAMS_OAUTH = CONN_PARAMS_OAUTH_BASE | {"refresh_token": "secrettoken"}
+
 
 @pytest.fixture
 def unencrypted_temporary_private_key(tmp_path: Path) -> Path:
@@ -559,6 +560,112 @@ class TestPytestSnowflakeHook:
         assert "region" in conn_params_extra_keys
         assert "account" in conn_params_extra_keys
 
+    @mock.patch("requests.post")
+    @mock.patch(
+        
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
+    def test_get_conn_params_should_support_oauth_with_token_endpoint(
+        self, mock_get_conn_params, requests_post
+    ):
+        requests_post.return_value = Mock(
+            status_code=200,
+            json=lambda: {
+                "access_token": "supersecretaccesstoken",
+                "expires_in": 600,
+                "refresh_token": "secrettoken",
+                "token_type": "Bearer",
+                "username": "test_user",
+            },
+        )
+        connection_kwargs = {
+            **BASE_CONNECTION_KWARGS,
+            "login": "test_client_id",
+            "password": "test_client_secret",
+            "extra": {
+                "database": "db",
+                "account": "airflow",
+                "warehouse": "af_wh",
+                "region": "af_region",
+                "role": "af_role",
+                "refresh_token": "secrettoken",
+                "authenticator": "oauth",
+                "token_endpoint": "https://www.example.com/oauth/token";,
+            },
+        }
+        mock_get_conn_params.return_value = connection_kwargs
+        with mock.patch.dict("os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            conn_params = hook._get_conn_params
+
+        conn_params_keys = conn_params.keys()
+        conn_params_extra = conn_params.get("extra", {})
+        conn_params_extra_keys = conn_params_extra.keys()
+
+        assert "authenticator" in conn_params_extra_keys
+        assert conn_params_extra["authenticator"] == "oauth"
+        assert conn_params_extra["token_endpoint"] == 
"https://www.example.com/oauth/token";
+
+        assert "user" not in conn_params_keys
+        assert "password" in conn_params_keys
+        assert "refresh_token" in conn_params_extra_keys
+        # Mandatory fields to generate account_identifier 
`https://<account>.<region>`
+        assert "region" in conn_params_extra_keys
+        assert "account" in conn_params_extra_keys
+
+    @mock.patch("requests.post")
+    @mock.patch(
+        
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
+    def test_get_conn_params_should_support_oauth_with_client_credentials(
+        self, mock_get_conn_params, requests_post
+    ):
+        requests_post.return_value = Mock(
+            status_code=200,
+            json=lambda: {
+                "access_token": "supersecretaccesstoken",
+                "expires_in": 600,
+                "refresh_token": "secrettoken",
+                "token_type": "Bearer",
+                "username": "test_user",
+            },
+        )
+        connection_kwargs = {
+            **BASE_CONNECTION_KWARGS,
+            "login": "test_client_id",
+            "password": "test_client_secret",
+            "extra": {
+                "database": "db",
+                "account": "airflow",
+                "warehouse": "af_wh",
+                "region": "af_region",
+                "role": "af_role",
+                "authenticator": "oauth",
+                "token_endpoint": "https://www.example.com/oauth/token";,
+                "grant_type": "client_credentials",
+            },
+        }
+        mock_get_conn_params.return_value = connection_kwargs
+        with mock.patch.dict("os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+            hook = SnowflakeHook(snowflake_conn_id="test_conn")
+            conn_params = hook._get_conn_params
+
+        conn_params_keys = conn_params.keys()
+        conn_params_extra = conn_params.get("extra", {})
+        conn_params_extra_keys = conn_params_extra.keys()
+
+        assert "authenticator" in conn_params_extra_keys
+        assert conn_params_extra["authenticator"] == "oauth"
+        assert conn_params_extra["grant_type"] == "client_credentials"
+
+        assert "user" not in conn_params_keys
+        assert "password" in conn_params_keys
+        assert "refresh_token" not in conn_params_extra_keys
+        # Mandatory fields to generate account_identifier 
`https://<account>.<region>`
+        assert "region" in conn_params_extra_keys
+        assert "account" in conn_params_extra_keys
+
     def test_should_add_partner_info(self):
         with mock.patch.dict(
             "os.environ",
@@ -917,3 +1024,31 @@ class TestPytestSnowflakeHook:
             headers={"Content-Type": "application/x-www-form-urlencoded"},
             auth=basic_auth,
         )
+
+    @mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth")
+    @mock.patch("requests.post")
+    @mock.patch(
+        
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
+        new_callable=PropertyMock,
+    )
+    def test_get_oauth_token_with_token_endpoint(self, mock_conn_param, 
requests_post, mock_auth):
+        """Test get_oauth_token method makes the right http request"""
+        basic_auth = {"Authorization": "Basic usernamepassword"}
+        token_endpoint = "https://example.com/oauth/token";
+        mock_conn_param.return_value = CONN_PARAMS_OAUTH
+        requests_post.return_value.status_code = 200
+        mock_auth.return_value = basic_auth
+
+        hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
+        hook.get_oauth_token(conn_config=CONN_PARAMS_OAUTH, 
token_endpoint=token_endpoint)
+
+        requests_post.assert_called_once_with(
+            token_endpoint,
+            data={
+                "grant_type": "refresh_token",
+                "refresh_token": CONN_PARAMS_OAUTH["refresh_token"],
+                "redirect_uri": "https://localhost.com";,
+            },
+            headers={"Content-Type": "application/x-www-form-urlencoded"},
+            auth=basic_auth,
+        )

Reply via email to