This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch support-pkce in repository https://gitbox.apache.org/repos/asf/superset.git
commit 20211f29b2fdbb6a44b9c19fb71991adc2153944 Author: Beto Dealmeida <[email protected]> AuthorDate: Mon Jan 12 17:22:30 2026 -0500 feat: support PKCE in OAuth2 flow --- superset/commands/database/oauth2.py | 3 + superset/db_engine_specs/base.py | 39 +++++++- superset/superset_typing.py | 4 +- superset/utils/oauth2.py | 58 +++++++++-- tests/unit_tests/db_engine_specs/test_base.py | 120 +++++++++++++++++++++++ tests/unit_tests/utils/oauth2_tests.py | 135 +++++++++++++++++++++++++- 6 files changed, 345 insertions(+), 14 deletions(-) diff --git a/superset/commands/database/oauth2.py b/superset/commands/database/oauth2.py index f7259077bc..71908b3caa 100644 --- a/superset/commands/database/oauth2.py +++ b/superset/commands/database/oauth2.py @@ -50,9 +50,12 @@ class OAuth2StoreTokenCommand(BaseCommand): if oauth2_config is None: raise OAuth2Error("No configuration found for OAuth2") + # Pass PKCE code_verifier if present in state (RFC 7636) + code_verifier = self._state.get("code_verifier") token_response = self._database.db_engine_spec.get_oauth2_token( oauth2_config, self._parameters["code"], + code_verifier=code_verifier, ) # delete old tokens diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 6c0cd77478..430e0bee9a 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -83,7 +83,11 @@ from superset.utils.core import ColumnSpec, GenericDataType, QuerySource from superset.utils.hashing import hash_from_str from superset.utils.json import redact_sensitive, reveal_sensitive from superset.utils.network import is_hostname_valid, is_port_open -from superset.utils.oauth2 import encode_oauth2_state +from superset.utils.oauth2 import ( + encode_oauth2_state, + generate_code_challenge, + generate_code_verifier, +) if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn @@ -474,10 +478,17 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods tab sends a message to the original tab informing that authorization was successful (or not), and then closes. The original tab will automatically re-run the query after authorization. + + PKCE (RFC 7636) is used to protect against authorization code interception + attacks. A code_verifier is generated and stored in the state, while the + code_challenge (derived from the verifier) is sent to the authorization server. """ tab_id = str(uuid4()) default_redirect_uri = url_for("DatabaseRestApi.oauth2", _external=True) + # Generate PKCE code verifier (RFC 7636) + code_verifier = generate_code_verifier() + # The state is passed to the OAuth2 provider, and sent back to Superset after # the user authorizes the access. The redirect endpoint in Superset can then # inspect the state to figure out to which user/database the access token @@ -499,6 +510,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # UUID to the original tab, and the second tab will use it when sending the # message. "tab_id": tab_id, + # PKCE code verifier stored in state to be retrieved during token exchange + "code_verifier": code_verifier, } oauth2_config = database.get_oauth2_config() if oauth2_config is None: @@ -552,17 +565,24 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ Return URI for initial OAuth2 request. - Uses standard OAuth 2.0 parameters only. Subclasses can override - to add provider-specific parameters (e.g., Google's prompt=consent). + Uses standard OAuth 2.0 parameters plus PKCE (RFC 7636) parameters. + Subclasses can override to add provider-specific parameters + (e.g., Google's prompt=consent). """ uri = config["authorization_request_uri"] - params = { + params: dict[str, str] = { "scope": config["scope"], "response_type": "code", "state": encode_oauth2_state(state), "redirect_uri": config["redirect_uri"], "client_id": config["id"], } + + # Add PKCE parameters (RFC 7636) if code_verifier is present in state + if "code_verifier" in state: + params["code_challenge"] = generate_code_challenge(state["code_verifier"]) + params["code_challenge_method"] = "S256" + return urljoin(uri, "?" + urlencode(params)) @classmethod @@ -570,19 +590,28 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, config: OAuth2ClientConfig, code: str, + code_verifier: str | None = None, ) -> OAuth2TokenResponse: """ Exchange authorization code for refresh/access tokens. + + If code_verifier is provided (PKCE flow), it will be included in the + token request per RFC 7636. """ timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() uri = config["token_request_uri"] - req_body = { + req_body: dict[str, str] = { "code": code, "client_id": config["id"], "client_secret": config["secret"], "redirect_uri": config["redirect_uri"], "grant_type": "authorization_code", } + + # Add PKCE code_verifier if present (RFC 7636) + if code_verifier: + req_body["code_verifier"] = code_verifier + if config["request_content_type"] == "data": return requests.post(uri, data=req_body, timeout=timeout).json() return requests.post(uri, json=req_body, timeout=timeout).json() diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 105a28d4cf..a1d36e811f 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -354,7 +354,7 @@ class OAuth2TokenResponse(TypedDict, total=False): refresh_token: str -class OAuth2State(TypedDict): +class OAuth2State(TypedDict, total=False): """ Type for the state passed during OAuth2. """ @@ -363,3 +363,5 @@ class OAuth2State(TypedDict): user_id: int default_redirect_uri: str tab_id: str + # PKCE code verifier (RFC 7636) - stored in state during token exchange + code_verifier: str diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index ebe1f4012e..02fba6ac0a 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -17,6 +17,9 @@ from __future__ import annotations +import base64 +import hashlib +import secrets from contextlib import contextmanager from datetime import datetime, timedelta, timezone from typing import Any, Iterator, TYPE_CHECKING @@ -37,6 +40,37 @@ if TYPE_CHECKING: JWT_EXPIRATION = timedelta(minutes=5) +# PKCE code verifier length (RFC 7636 recommends 43-128 characters) +PKCE_CODE_VERIFIER_LENGTH = 64 + + +def generate_code_verifier() -> str: + """ + Generate a PKCE code verifier (RFC 7636). + + The code verifier is a high-entropy cryptographic random string using + unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~", + with a minimum length of 43 characters and a maximum length of 128. + """ + # Generate random bytes and encode as URL-safe base64 + random_bytes = secrets.token_bytes(PKCE_CODE_VERIFIER_LENGTH) + # Use URL-safe base64 encoding without padding + code_verifier = base64.urlsafe_b64encode(random_bytes).rstrip(b"=").decode("ascii") + return code_verifier + + +def generate_code_challenge(code_verifier: str) -> str: + """ + Generate a PKCE code challenge from a code verifier (RFC 7636). + + Uses the S256 method: BASE64URL(SHA256(code_verifier)) + """ + # Compute SHA-256 hash of the code verifier + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + # Encode as URL-safe base64 without padding + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + return code_challenge + @backoff.on_exception( backoff.expo, @@ -119,13 +153,17 @@ def encode_oauth2_state(state: OAuth2State) -> str: """ Encode the OAuth2 state. """ - payload = { + payload: dict[str, Any] = { "exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION, "database_id": state["database_id"], "user_id": state["user_id"], "default_redirect_uri": state["default_redirect_uri"], "tab_id": state["tab_id"], } + # Include PKCE code_verifier if present (RFC 7636) + if "code_verifier" in state: + payload["code_verifier"] = state["code_verifier"] + encoded_state = jwt.encode( payload=payload, key=app.config["SECRET_KEY"], @@ -143,6 +181,8 @@ class OAuth2StateSchema(Schema): user_id = fields.Int(required=True) default_redirect_uri = fields.Str(required=True) tab_id = fields.Str(required=True) + # PKCE code verifier (RFC 7636) - optional for backward compatibility + code_verifier = fields.Str(required=False, load_default=None) # pylint: disable=unused-argument @post_load @@ -151,12 +191,16 @@ class OAuth2StateSchema(Schema): data: dict[str, Any], **kwargs: Any, ) -> OAuth2State: - return OAuth2State( - database_id=data["database_id"], - user_id=data["user_id"], - default_redirect_uri=data["default_redirect_uri"], - tab_id=data["tab_id"], - ) + state: OAuth2State = { + "database_id": data["database_id"], + "user_id": data["user_id"], + "default_redirect_uri": data["default_redirect_uri"], + "tab_id": data["tab_id"], + } + # Include code_verifier if present (PKCE) + if data.get("code_verifier"): + state["code_verifier"] = data["code_verifier"] + return state class Meta: # pylint: disable=too-few-public-methods # ignore `exp` diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index bff4c93117..7a4b8f9206 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -943,3 +943,123 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) -> assert "prompt" not in query assert "access_type" not in query assert "include_granted_scopes" not in query + + # Verify PKCE parameters are NOT included when code_verifier is not in state + assert "code_challenge" not in query + assert "code_challenge_method" not in query + + +def test_get_oauth2_authorization_uri_with_pkce(mocker: MockerFixture) -> None: + """ + Test that BaseEngineSpec.get_oauth2_authorization_uri includes PKCE parameters + when code_verifier is present in state (RFC 7636). + """ + from urllib.parse import parse_qs, urlparse + + from superset.db_engine_specs.base import BaseEngineSpec + from superset.superset_typing import OAuth2ClientConfig, OAuth2State + from superset.utils.oauth2 import generate_code_challenge, generate_code_verifier + + config: OAuth2ClientConfig = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + code_verifier = generate_code_verifier() + state: OAuth2State = { + "database_id": 1, + "user_id": 1, + "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", + "tab_id": "1234", + "code_verifier": code_verifier, + } + + url = BaseEngineSpec.get_oauth2_authorization_uri(config, state) + parsed = urlparse(url) + query = parse_qs(parsed.query) + + # Verify PKCE parameters are included (RFC 7636) + assert "code_challenge" in query + assert query["code_challenge_method"][0] == "S256" + # Verify the code_challenge matches the expected value + expected_challenge = generate_code_challenge(code_verifier) + assert query["code_challenge"][0] == expected_challenge + + +def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None: + """ + Test that BaseEngineSpec.get_oauth2_token works without PKCE code_verifier. + """ + from superset.db_engine_specs.base import BaseEngineSpec + from superset.superset_typing import OAuth2ClientConfig + + mocker.patch( + "flask.current_app.config", + {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)}, + ) + mock_post = mocker.patch("superset.db_engine_specs.base.requests.post") + mock_post.return_value.json.return_value = { + "access_token": "test-access-token", # noqa: S105 + "expires_in": 3600, + } + + config: OAuth2ClientConfig = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + result = BaseEngineSpec.get_oauth2_token(config, "auth-code") + + assert result["access_token"] == "test-access-token" # noqa: S105 + # Verify code_verifier is NOT in the request body + call_kwargs = mock_post.call_args + request_body = call_kwargs.kwargs.get("json") or call_kwargs.kwargs.get("data") + assert "code_verifier" not in request_body + + +def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None: + """ + Test BaseEngineSpec.get_oauth2_token includes code_verifier when provided. + """ + from superset.db_engine_specs.base import BaseEngineSpec + from superset.superset_typing import OAuth2ClientConfig + from superset.utils.oauth2 import generate_code_verifier + + mocker.patch( + "flask.current_app.config", + {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)}, + ) + mock_post = mocker.patch("superset.db_engine_specs.base.requests.post") + mock_post.return_value.json.return_value = { + "access_token": "test-access-token", # noqa: S105 + "expires_in": 3600, + } + + config: OAuth2ClientConfig = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + code_verifier = generate_code_verifier() + result = BaseEngineSpec.get_oauth2_token(config, "auth-code", code_verifier) + + assert result["access_token"] == "test-access-token" # noqa: S105 + # Verify code_verifier IS in the request body (PKCE) + call_kwargs = mock_post.call_args + request_body = call_kwargs.kwargs.get("json") or call_kwargs.kwargs.get("data") + assert request_body["code_verifier"] == code_verifier diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py index e9aa283b1a..0b14e1868a 100644 --- a/tests/unit_tests/utils/oauth2_tests.py +++ b/tests/unit_tests/utils/oauth2_tests.py @@ -17,12 +17,20 @@ # pylint: disable=invalid-name, disallowed-name +import base64 +import hashlib from datetime import datetime from freezegun import freeze_time from pytest_mock import MockerFixture -from superset.utils.oauth2 import get_oauth2_access_token +from superset.utils.oauth2 import ( + decode_oauth2_state, + encode_oauth2_state, + generate_code_challenge, + generate_code_verifier, + get_oauth2_access_token, +) def test_get_oauth2_access_token_base_no_token(mocker: MockerFixture) -> None: @@ -93,3 +101,128 @@ def test_get_oauth2_access_token_base_no_refresh(mocker: MockerFixture) -> None: # check that token was deleted db.session.delete.assert_called_with(token) + + +def test_generate_code_verifier_length() -> None: + """ + Test that generate_code_verifier produces a string of valid length (RFC 7636). + """ + code_verifier = generate_code_verifier() + # RFC 7636 requires 43-128 characters + assert 43 <= len(code_verifier) <= 128 + + +def test_generate_code_verifier_uniqueness() -> None: + """ + Test that generate_code_verifier produces unique values. + """ + verifiers = {generate_code_verifier() for _ in range(100)} + # All generated verifiers should be unique + assert len(verifiers) == 100 + + +def test_generate_code_verifier_valid_characters() -> None: + """ + Test that generate_code_verifier only uses valid characters (RFC 7636). + """ + code_verifier = generate_code_verifier() + # RFC 7636 allows: [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~" + # URL-safe base64 uses: [A-Z] / [a-z] / [0-9] / "-" / "_" + valid_chars = set( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" + ) + assert all(char in valid_chars for char in code_verifier) + + +def test_generate_code_challenge_s256() -> None: + """ + Test that generate_code_challenge produces correct S256 challenge. + """ + # Use a known code_verifier to verify the challenge computation + code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + + # Compute expected challenge manually + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + expected_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + + code_challenge = generate_code_challenge(code_verifier) + assert code_challenge == expected_challenge + + +def test_generate_code_challenge_rfc_example() -> None: + """ + Test PKCE code challenge against RFC 7636 Appendix B example. + + See: https://datatracker.ietf.org/doc/html/rfc7636#appendix-B + """ + # RFC 7636 example code_verifier (Appendix B) + code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + # RFC 7636 expected code_challenge for S256 method + expected_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + + code_challenge = generate_code_challenge(code_verifier) + assert code_challenge == expected_challenge + + +def test_encode_decode_oauth2_state_with_code_verifier(mocker: MockerFixture) -> None: + """ + Test that code_verifier is preserved through encode/decode cycle. + """ + from superset.superset_typing import OAuth2State + + mocker.patch( + "flask.current_app.config", + { + "SECRET_KEY": "test-secret-key", + "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", + }, + ) + + code_verifier = generate_code_verifier() + state: OAuth2State = { + "database_id": 1, + "user_id": 2, + "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", + "tab_id": "test-tab-id", + "code_verifier": code_verifier, + } + + with freeze_time("2024-01-01"): + encoded = encode_oauth2_state(state) + decoded = decode_oauth2_state(encoded) + + assert decoded["code_verifier"] == code_verifier + assert decoded["database_id"] == 1 + assert decoded["user_id"] == 2 + + +def test_encode_decode_oauth2_state_without_code_verifier( + mocker: MockerFixture, +) -> None: + """ + Test backward compatibility: state without code_verifier still works. + """ + from superset.superset_typing import OAuth2State + + mocker.patch( + "flask.current_app.config", + { + "SECRET_KEY": "test-secret-key", + "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", + }, + ) + + state: OAuth2State = { + "database_id": 1, + "user_id": 2, + "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", + "tab_id": "test-tab-id", + } + + with freeze_time("2024-01-01"): + encoded = encode_oauth2_state(state) + decoded = decode_oauth2_state(encoded) + + assert "code_verifier" not in decoded + assert decoded["database_id"] == 1 + assert decoded["user_id"] == 2
