This is an automated email from the ASF dual-hosted git repository.
vavila pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new fa346099523 feat: Support OAuth2 single-use refresh tokens (#38364)
fa346099523 is described below
commit fa3460995239786ca96351ba5f03c964f894c698
Author: Vitor Avila <[email protected]>
AuthorDate: Tue Mar 3 16:07:15 2026 -0300
feat: Support OAuth2 single-use refresh tokens (#38364)
---
superset/db_engine_specs/base.py | 6 ++
superset/utils/oauth2.py | 4 ++
tests/unit_tests/db_engine_specs/test_base.py | 91 +++++++++++++++++++++++++++
tests/unit_tests/utils/oauth2_tests.py | 56 +++++++++++++++++
4 files changed, 157 insertions(+)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index fb0e26e77e7..965eec46b12 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -572,6 +572,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
oauth2_token_request_uri: str | None = None
oauth2_token_request_type = "data" # noqa: S105
+ # Driver-specific query params to be included in `get_oauth2_authorization_uri`
+ oauth2_additional_auth_uri_query_params: dict[str, Any] = {}
+ # Driver-specific params to be included in the `get_oauth2_token` request body
+ oauth2_additional_token_request_params: dict[str, Any] = {}
# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
@@ -754,6 +758,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
"state": encode_oauth2_state(state),
"redirect_uri": config["redirect_uri"],
"client_id": config["id"],
+ **cls.oauth2_additional_auth_uri_query_params,
}
# Add PKCE parameters (RFC 7636) if code_verifier is provided
@@ -784,6 +789,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
+ **cls.oauth2_additional_token_request_params,
}
# Add PKCE code_verifier if present (RFC 7636)
if code_verifier:
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index 57cc0a25ce9..4978c0af5c5 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -167,6 +167,10 @@ def refresh_oauth2_token(
token.access_token_expiration = datetime.now() + timedelta(
seconds=token_response["expires_in"]
)
+ # Support single-use refresh tokens
+ if new_refresh_token := token_response.get("refresh_token"):
+ token.refresh_token = new_refresh_token
+
db.session.add(token)
return token.access_token
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index 6c6b98e0593..14fa82b55af 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -1052,6 +1052,97 @@ def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None:
assert request_body["code_verifier"] == code_verifier
+def test_get_oauth2_authorization_uri_additional_params(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that a subclass can inject additional query params into the authorization URI
+    via `oauth2_additional_auth_uri_query_params`.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+
+    class CustomEngineSpec(BaseEngineSpec):
+        oauth2_additional_auth_uri_query_params = {
+            "prompt": "consent",
+            "access_type": "offline",
+        }
+
+    config: OAuth2ClientConfig = {
+        "id": "client-id",
+        "secret": "client-secret",
+        "scope": "read write",
+        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+        "authorization_request_uri": "https://oauth.example.com/authorize",
+        "token_request_uri": "https://oauth.example.com/token",
+        "request_content_type": "json",
+    }
+
+    state: OAuth2State = {
+        "database_id": 1,
+        "user_id": 1,
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+        "tab_id": "1234",
+    }
+
+    url = CustomEngineSpec.get_oauth2_authorization_uri(config, state)
+    parsed = urlparse(url)
+    query = parse_qs(parsed.query)
+
+    # Standard params still present
+    assert query["response_type"][0] == "code"
+    assert query["client_id"][0] == "client-id"
+
+    # Additional params included
+    assert query["prompt"][0] == "consent"
+    assert query["access_type"][0] == "offline"
+
+
+def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None:
+    """
+    Test that a subclass can inject additional params into the token request body
+    via `oauth2_additional_token_request_params`.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+
+    class CustomEngineSpec(BaseEngineSpec):
+        oauth2_additional_token_request_params = {
+            "audience": "https://api.example.com",
+        }
+
+    mocker.patch(
+        "flask.current_app.config",
+        {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
+    )
+    mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
+    mock_post.return_value.json.return_value = {
+        "access_token": "test-access-token",  # noqa: S105
+        "expires_in": 3600,
+    }
+
+    config: OAuth2ClientConfig = {
+        "id": "client-id",
+        "secret": "client-secret",
+        "scope": "read write",
+        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+        "authorization_request_uri": "https://oauth.example.com/authorize",
+        "token_request_uri": "https://oauth.example.com/token",
+        "request_content_type": "json",
+    }
+
+    result = CustomEngineSpec.get_oauth2_token(config, "auth-code")
+
+    assert result["access_token"] == "test-access-token"  # noqa: S105
+    call_kwargs = mock_post.call_args
+    request_body = call_kwargs.kwargs.get("json") or call_kwargs.kwargs.get("data")
+
+    # Standard params still present
+    assert request_body["grant_type"] == "authorization_code"
+    assert request_body["client_id"] == "client-id"
+
+    # Additional param included
+    assert request_body["audience"] == "https://api.example.com"
+
+
def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None:
"""
Test that start_oauth2_dance uses DATABASE_OAUTH2_REDIRECT_URI config if set.
diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py
index 08b7cc9c6e7..f04ae26e7c2 100644
--- a/tests/unit_tests/utils/oauth2_tests.py
+++ b/tests/unit_tests/utils/oauth2_tests.py
@@ -188,6 +188,62 @@ def test_refresh_oauth2_token_no_access_token_in_response(
assert result is None
+def test_refresh_oauth2_token_updates_refresh_token(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that refresh_oauth2_token updates the refresh token when a new one is returned.
+
+    Some OAuth2 providers issue single-use refresh tokens, where each token refresh
+    response includes a new refresh token that replaces the previous one.
+    """
+    db = mocker.patch("superset.utils.oauth2.db")
+    mocker.patch("superset.utils.oauth2.DistributedLock")
+    db_engine_spec = mocker.MagicMock()
+    db_engine_spec.get_oauth2_fresh_token.return_value = {
+        "access_token": "new-access-token",
+        "expires_in": 3600,
+        "refresh_token": "new-refresh-token",
+    }
+    token = mocker.MagicMock()
+    token.refresh_token = "old-refresh-token"  # noqa: S105
+
+    with freeze_time("2024-01-01"):
+        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+
+    assert token.access_token == "new-access-token"  # noqa: S105
+    assert token.access_token_expiration == datetime(2024, 1, 1, 1)
+    assert token.refresh_token == "new-refresh-token"  # noqa: S105
+    db.session.add.assert_called_with(token)
+
+
+def test_refresh_oauth2_token_keeps_refresh_token(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that refresh_oauth2_token keeps the existing refresh token when none returned.
+
+    When the OAuth2 provider does not issue a new refresh token in the response,
+    the original refresh token should be preserved.
+    """
+    db = mocker.patch("superset.utils.oauth2.db")
+    mocker.patch("superset.utils.oauth2.DistributedLock")
+    db_engine_spec = mocker.MagicMock()
+    db_engine_spec.get_oauth2_fresh_token.return_value = {
+        "access_token": "new-access-token",
+        "expires_in": 3600,
+    }
+    token = mocker.MagicMock()
+    token.refresh_token = "original-refresh-token"  # noqa: S105
+
+    with freeze_time("2024-01-01"):
+        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+
+    assert token.access_token == "new-access-token"  # noqa: S105
+    assert token.refresh_token == "original-refresh-token"  # noqa: S105
+    db.session.add.assert_called_with(token)
+
+
def test_generate_code_verifier_length() -> None:
"""
Test that generate_code_verifier produces a string of valid length (RFC 7636).