This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6449262f4d9 Extend SnowflakeHook OAuth implementation to support
external IDPs and client_credentials grant (#51620)
6449262f4d9 is described below
commit 6449262f4d9de2f3a43967550a9ee08421aa6268
Author: Philippe Gagnon <[email protected]>
AuthorDate: Wed Jun 11 16:27:39 2025 -0400
Extend SnowflakeHook OAuth implementation to support external IDPs and
client_credentials grant (#51620)
This PR extends `SnowflakeHook`'s OAuth implementation, allowing users to
specify a `token_endpoint` connection extra to support external identity
providers (IdPs) other than Snowflake's built-in OAuth server.
It also introduces the `grant_type` connection extra in order to support grant
types other than `refresh_token`. This initial implementation adds support
for the `client_credentials` grant type: the client ID and client secret are
taken from the connection's login/password fields and used to acquire an
access token.
Backwards compatibility with the existing interface is preserved.
---
providers/snowflake/docs/connections/snowflake.rst | 2 +
.../airflow/providers/snowflake/hooks/snowflake.py | 39 ++++--
.../providers/snowflake/hooks/snowflake_sql_api.py | 11 +-
.../tests/unit/snowflake/hooks/test_snowflake.py | 139 ++++++++++++++++++++-
4 files changed, 180 insertions(+), 11 deletions(-)
diff --git a/providers/snowflake/docs/connections/snowflake.rst
b/providers/snowflake/docs/connections/snowflake.rst
index 9c8c6a4a875..e15280ceb77 100644
--- a/providers/snowflake/docs/connections/snowflake.rst
+++ b/providers/snowflake/docs/connections/snowflake.rst
@@ -58,6 +58,8 @@ Extra (optional)
* ``warehouse``: Snowflake warehouse name.
* ``role``: Snowflake role.
* ``authenticator``: To connect using OAuth set this parameter ``oauth``.
+ * ``token_endpoint``: Specify token endpoint for external OAuth provider.
+ * ``grant_type``: Specify grant type for OAuth authentication. Currently
supported: ``refresh_token`` (default), ``client_credentials``.
* ``refresh_token``: Specify refresh_token for OAuth connection.
* ``private_key_file``: Specify the path to the private key file.
* ``private_key_content``: Specify the content of the private key file in
base64 encoded format. You can use the following Python code to encode the
private key:
diff --git
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
index 5535a5e3549..88cb5f331e2 100644
--- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
+++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py
@@ -136,6 +136,9 @@ class SnowflakeHook(DbApiHook):
"session_parameters": "session parameters",
"client_request_mfa_token": "client request mfa token",
"client_store_temporary_credential": "client store
temporary credential (externalbrowser mode)",
+ "grant_type": "refresh_token client_credentials",
+ "token_endpoint": "token endpoint",
+ "refresh_token": "refresh token",
},
indent=1,
),
@@ -200,18 +203,32 @@ class SnowflakeHook(DbApiHook):
return account_identifier
- def get_oauth_token(self, conn_config: dict | None = None) -> str:
+ def get_oauth_token(
+ self,
+ conn_config: dict | None = None,
+ token_endpoint: str | None = None,
+ grant_type: str = "refresh_token",
+ ) -> str:
"""Generate temporary OAuth access token using refresh token in
connection details."""
if conn_config is None:
conn_config = self._get_conn_params
- url =
f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
+ url = token_endpoint or
f"https://{conn_config['account']}.snowflakecomputing.com/oauth/token-request"
data = {
- "grant_type": "refresh_token",
- "refresh_token": conn_config["refresh_token"],
+ "grant_type": grant_type,
"redirect_uri": conn_config.get("redirect_uri",
"https://localhost.com"),
}
+
+ if grant_type == "refresh_token":
+ data |= {
+ "refresh_token": conn_config["refresh_token"],
+ }
+ elif grant_type == "client_credentials":
+ pass # no setup necessary for client credentials grant.
+ else:
+ raise ValueError(f"Unknown grant_type: {grant_type}")
+
response = requests.post(
url,
data=data,
@@ -226,7 +243,8 @@ class SnowflakeHook(DbApiHook):
except requests.exceptions.HTTPError as e: # pragma: no cover
msg = f"Response: {e.response.content.decode()} Status Code:
{e.response.status_code}"
raise AirflowException(msg)
- return response.json()["access_token"]
+ token = response.json()["access_token"]
+ return token
@cached_property
def _get_conn_params(self) -> dict[str, str | None]:
@@ -329,14 +347,21 @@ class SnowflakeHook(DbApiHook):
if refresh_token:
conn_config["refresh_token"] = refresh_token
conn_config["authenticator"] = "oauth"
+
+ if conn_config.get("authenticator") == "oauth":
+ token_endpoint = self._get_field(extra_dict, "token_endpoint") or
""
conn_config["client_id"] = conn.login
conn_config["client_secret"] = conn.password
+ conn_config["token"] = self.get_oauth_token(
+ conn_config=conn_config,
+ token_endpoint=token_endpoint,
+ grant_type=extra_dict.get("grant_type", "refresh_token"),
+ )
+
conn_config.pop("login", None)
conn_config.pop("user", None)
conn_config.pop("password", None)
- conn_config["token"] =
self.get_oauth_token(conn_config=conn_config)
-
# configure custom target hostname and port, if specified
snowflake_host = extra_dict.get("host")
snowflake_port = extra_dict.get("port")
diff --git
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index 0a2c9fd4247..7e4ef8dfe02 100644
---
a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++
b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -223,14 +223,21 @@ class SnowflakeSqlApiHook(SnowflakeHook):
}
return headers
- def get_oauth_token(self, conn_config: dict[str, Any] | None = None) ->
str:
+ def get_oauth_token(
+ self,
+ conn_config: dict[str, Any] | None = None,
+ token_endpoint: str | None = None,
+ grant_type: str = "refresh_token",
+ ) -> str:
"""Generate temporary OAuth access token using refresh token in
connection details."""
warnings.warn(
"This method is deprecated. Please use `get_oauth_token` method
from `SnowflakeHook` instead. ",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
- return super().get_oauth_token(conn_config=conn_config)
+ return super().get_oauth_token(
+ conn_config=conn_config, token_endpoint=token_endpoint,
grant_type=grant_type
+ )
def get_request_url_header_params(self, query_id: str) -> tuple[dict[str,
Any], dict[str, Any], str]:
"""
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index ec384ddc5c8..e28d0bca28e 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -53,14 +53,13 @@ BASE_CONNECTION_KWARGS: dict = {
},
}
-CONN_PARAMS_OAUTH = {
+CONN_PARAMS_OAUTH_BASE = {
"account": "airflow",
"application": "AIRFLOW",
"authenticator": "oauth",
"database": "db",
"client_id": "test_client_id",
"client_secret": "test_client_pw",
- "refresh_token": "secrettoken",
"region": "af_region",
"role": "af_role",
"schema": "public",
@@ -68,6 +67,8 @@ CONN_PARAMS_OAUTH = {
"warehouse": "af_wh",
}
+CONN_PARAMS_OAUTH = CONN_PARAMS_OAUTH_BASE | {"refresh_token": "secrettoken"}
+
@pytest.fixture
def unencrypted_temporary_private_key(tmp_path: Path) -> Path:
@@ -559,6 +560,112 @@ class TestPytestSnowflakeHook:
assert "region" in conn_params_extra_keys
assert "account" in conn_params_extra_keys
+ @mock.patch("requests.post")
+ @mock.patch(
+
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
+ new_callable=PropertyMock,
+ )
+ def test_get_conn_params_should_support_oauth_with_token_endpoint(
+ self, mock_get_conn_params, requests_post
+ ):
+ requests_post.return_value = Mock(
+ status_code=200,
+ json=lambda: {
+ "access_token": "supersecretaccesstoken",
+ "expires_in": 600,
+ "refresh_token": "secrettoken",
+ "token_type": "Bearer",
+ "username": "test_user",
+ },
+ )
+ connection_kwargs = {
+ **BASE_CONNECTION_KWARGS,
+ "login": "test_client_id",
+ "password": "test_client_secret",
+ "extra": {
+ "database": "db",
+ "account": "airflow",
+ "warehouse": "af_wh",
+ "region": "af_region",
+ "role": "af_role",
+ "refresh_token": "secrettoken",
+ "authenticator": "oauth",
+ "token_endpoint": "https://www.example.com/oauth/token",
+ },
+ }
+ mock_get_conn_params.return_value = connection_kwargs
+ with mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ conn_params = hook._get_conn_params
+
+ conn_params_keys = conn_params.keys()
+ conn_params_extra = conn_params.get("extra", {})
+ conn_params_extra_keys = conn_params_extra.keys()
+
+ assert "authenticator" in conn_params_extra_keys
+ assert conn_params_extra["authenticator"] == "oauth"
+ assert conn_params_extra["token_endpoint"] ==
"https://www.example.com/oauth/token"
+
+ assert "user" not in conn_params_keys
+ assert "password" in conn_params_keys
+ assert "refresh_token" in conn_params_extra_keys
+ # Mandatory fields to generate account_identifier
`https://<account>.<region>`
+ assert "region" in conn_params_extra_keys
+ assert "account" in conn_params_extra_keys
+
+ @mock.patch("requests.post")
+ @mock.patch(
+
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
+ new_callable=PropertyMock,
+ )
+ def test_get_conn_params_should_support_oauth_with_client_credentials(
+ self, mock_get_conn_params, requests_post
+ ):
+ requests_post.return_value = Mock(
+ status_code=200,
+ json=lambda: {
+ "access_token": "supersecretaccesstoken",
+ "expires_in": 600,
+ "refresh_token": "secrettoken",
+ "token_type": "Bearer",
+ "username": "test_user",
+ },
+ )
+ connection_kwargs = {
+ **BASE_CONNECTION_KWARGS,
+ "login": "test_client_id",
+ "password": "test_client_secret",
+ "extra": {
+ "database": "db",
+ "account": "airflow",
+ "warehouse": "af_wh",
+ "region": "af_region",
+ "role": "af_role",
+ "authenticator": "oauth",
+ "token_endpoint": "https://www.example.com/oauth/token",
+ "grant_type": "client_credentials",
+ },
+ }
+ mock_get_conn_params.return_value = connection_kwargs
+ with mock.patch.dict("os.environ",
AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ conn_params = hook._get_conn_params
+
+ conn_params_keys = conn_params.keys()
+ conn_params_extra = conn_params.get("extra", {})
+ conn_params_extra_keys = conn_params_extra.keys()
+
+ assert "authenticator" in conn_params_extra_keys
+ assert conn_params_extra["authenticator"] == "oauth"
+ assert conn_params_extra["grant_type"] == "client_credentials"
+
+ assert "user" not in conn_params_keys
+ assert "password" in conn_params_keys
+ assert "refresh_token" not in conn_params_extra_keys
+ # Mandatory fields to generate account_identifier
`https://<account>.<region>`
+ assert "region" in conn_params_extra_keys
+ assert "account" in conn_params_extra_keys
+
def test_should_add_partner_info(self):
with mock.patch.dict(
"os.environ",
@@ -917,3 +1024,31 @@ class TestPytestSnowflakeHook:
headers={"Content-Type": "application/x-www-form-urlencoded"},
auth=basic_auth,
)
+
+ @mock.patch("airflow.providers.snowflake.hooks.snowflake.HTTPBasicAuth")
+ @mock.patch("requests.post")
+ @mock.patch(
+
"airflow.providers.snowflake.hooks.snowflake.SnowflakeHook._get_conn_params",
+ new_callable=PropertyMock,
+ )
+ def test_get_oauth_token_with_token_endpoint(self, mock_conn_param,
requests_post, mock_auth):
+ """Test get_oauth_token method makes the right http request"""
+ basic_auth = {"Authorization": "Basic usernamepassword"}
+ token_endpoint = "https://example.com/oauth/token"
+ mock_conn_param.return_value = CONN_PARAMS_OAUTH
+ requests_post.return_value.status_code = 200
+ mock_auth.return_value = basic_auth
+
+ hook = SnowflakeHook(snowflake_conn_id="mock_conn_id")
+ hook.get_oauth_token(conn_config=CONN_PARAMS_OAUTH,
token_endpoint=token_endpoint)
+
+ requests_post.assert_called_once_with(
+ token_endpoint,
+ data={
+ "grant_type": "refresh_token",
+ "refresh_token": CONN_PARAMS_OAUTH["refresh_token"],
+ "redirect_uri": "https://localhost.com",
+ },
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ auth=basic_auth,
+ )