This is an automated email from the ASF dual-hosted git repository. betodealmeida pushed a commit to branch oauth-during-db-creation in repository https://gitbox.apache.org/repos/asf/superset.git
commit c8af308afbc781b33da01564e80b37b847928964 Author: Beto Dealmeida <[email protected]> AuthorDate: Wed May 20 17:37:46 2026 -0400 feat: OAuth2 during DB creation --- .../src/features/databases/DatabaseModal/index.tsx | 56 ++++++++ superset/commands/database/oauth2.py | 158 +++++++++++++++------ superset/commands/database/test_connection.py | 8 ++ superset/databases/schemas.py | 10 ++ superset/db_engine_specs/base.py | 27 +++- superset/key_value/types.py | 1 + superset/superset_typing.py | 10 +- superset/utils/oauth2.py | 62 +++++++- tests/unit_tests/commands/databases/oauth2_test.py | 77 +++++++++- tests/unit_tests/utils/oauth2_tests.py | 77 ++++++++++ 10 files changed, 438 insertions(+), 48 deletions(-) diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.tsx index 1fa38c5e152..0600e16d57a 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/index.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/index.tsx @@ -622,6 +622,15 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({ const [editNewDb, setEditNewDb] = useState<boolean>(false); const [isLoading, setLoading] = useState<boolean>(false); const [testInProgress, setTestInProgress] = useState<boolean>(false); + // Stable id sent on every "Test Connection" call so that the backend can + // correlate the wizard's OAuth2 dance across requests. Generated once per + // modal session — keeping the same value lets the second test_connection + // call (after the user finishes the OAuth2 dance) find the cached token. + const oauth2TabIdRef = useRef<string>( + typeof crypto !== 'undefined' && crypto.randomUUID + ? crypto.randomUUID() + : Math.random().toString(36).slice(2), + ); const [passwords, setPasswords] = useState<Record<string, string>>({}); const [sshTunnelPasswords, setSSHTunnelPasswords] = useState< Record<string, string> @@ -727,6 +736,11 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({ }, [setValidationErrors, setHasValidated, clearError]); // Test Connection logic + // keep the latest ``testConnection`` callable accessible to long-lived + // listeners (BroadcastChannel/storage) without making them depend on every + // render of the modal. + const testConnectionRef = useRef<() => void>(() => {}); + const testConnection = () => { handleClearValidationErrors(); if (!db?.sqlalchemy_uri) { @@ -748,6 +762,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({ server_port: Number(db.ssh_tunnel!.server_port), } : undefined, + oauth2_tab_id: oauth2TabIdRef.current, }; setTestInProgress(true); testDatabaseConnection( @@ -764,6 +779,47 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({ }, ); }; + testConnectionRef.current = testConnection; + + // Re-run "Test Connection" automatically when the OAuth2 dance completes + // in the popup tab. The dance posts a message on the ``oauth`` broadcast + // channel (and a localStorage event for cross-context delivery) carrying + // the wizard's own tab_id; we react only to our own. + useEffect(() => { + const tabId = oauth2TabIdRef.current; + + const handleComplete = (incomingTabId?: string) => { + if (incomingTabId === tabId) { + testConnectionRef.current(); + } + }; + + const channel = + typeof BroadcastChannel !== 'undefined' + ? new BroadcastChannel('oauth') + : null; + if (channel) { + channel.onmessage = event => handleComplete(event.data?.tabId); + } + + const handleStorage = (event: StorageEvent) => { + if (event.key !== 'oauth2_auth_complete' || !event.newValue) { + return; + } + try { + const payload = JSON.parse(event.newValue) as { tabId?: string }; + handleComplete(payload.tabId); + } catch { + /* ignore */ + } + }; + window.addEventListener('storage', handleStorage); + + return () => { + window.removeEventListener('storage', handleStorage); + channel?.close(); + }; + }, []); const getPlaceholder = (field: string) => { if (field === 'database') { diff --git a/superset/commands/database/oauth2.py b/superset/commands/database/oauth2.py index 8355bc0098e..2df792bf429 100644 --- a/superset/commands/database/oauth2.py +++ b/superset/commands/database/oauth2.py @@ -25,66 +25,157 @@ from superset.commands.database.exceptions import DatabaseNotFoundError from superset.daos.database import DatabaseUserOAuth2TokensDAO from superset.daos.key_value import KeyValueDAO from superset.databases.schemas import OAuth2ProviderResponseSchema +from superset.db_engine_specs import get_engine_spec +from superset.db_engine_specs.base import BaseEngineSpec from superset.exceptions import OAuth2Error from superset.key_value.types import JsonKeyValueCodec, KeyValueResource from superset.models.core import Database, DatabaseUserOAuth2Tokens -from superset.superset_typing import OAuth2State +from superset.superset_typing import ( + OAuth2ClientConfig, + OAuth2State, + OAuth2TokenResponse, +) from superset.utils.decorators import on_error, transaction from superset.utils.oauth2 import decode_oauth2_state +# how long the pre-create token cache lives in the KV store +PRE_CREATE_TOKEN_TTL = timedelta(minutes=5) + class OAuth2StoreTokenCommand(BaseCommand): """ - Command to store OAuth2 tokens in the database. + Command to store OAuth2 tokens. + + Normal flow: the OAuth2 callback resolves the database via ``state.database_id`` + and persists access/refresh tokens to ``database_user_oauth2_tokens``. + + Pre-create flow: when ``state.database_id`` is ``None`` (the database hasn't + been saved yet — typically the "Create database" wizard), the command reads + the OAuth2 client config and engine name from the KV store entry that + :meth:`BaseEngineSpec.start_oauth2_dance` stashed there, exchanges the code, + and caches the resulting access token in the same KV entry for + :func:`get_oauth2_access_token` to pick up on the retry. """ def __init__(self, parameters: OAuth2ProviderResponseSchema): self._parameters = parameters self._state: OAuth2State | None = None self._database: Database | None = None + self._oauth2_config: OAuth2ClientConfig | None = None + self._engine_spec: type[BaseEngineSpec] | None = None + self._tab_uuid: UUID | None = None @transaction(on_error=partial(on_error, reraise=OAuth2Error)) - def run(self) -> DatabaseUserOAuth2Tokens: + def run(self) -> DatabaseUserOAuth2Tokens | None: self.validate() - self._database = cast(Database, self._database) self._state = cast(OAuth2State, self._state) - - oauth2_config = self._database.get_oauth2_config() - if oauth2_config is None: - raise OAuth2Error("No configuration found for OAuth2") + self._oauth2_config = cast(OAuth2ClientConfig, self._oauth2_config) + self._engine_spec = cast(type[BaseEngineSpec], self._engine_spec) # Look up PKCE code_verifier from KV store (RFC 7636) - code_verifier = None - tab_id = self._state["tab_id"] + code_verifier = self._pop_code_verifier() + + token_response = self._engine_spec.get_oauth2_token( + self._oauth2_config, + self._parameters["code"], + code_verifier=code_verifier, + ) + + if self._database is None: + # Pre-create flow: cache the access token in the KV entry the + # initial dance created. The retry of "Test Connection" will read + # it via ``get_oauth2_access_token``. + self._cache_pre_create_token(token_response) + return None + + return self._persist_token(token_response) + + def validate(self) -> None: + if error := self._parameters.get("error"): + raise OAuth2Error(error) + + self._state = decode_oauth2_state(self._parameters["state"]) + try: - tab_uuid = UUID(tab_id) - except ValueError: - tab_uuid = None - - if tab_uuid: - kv_value = KeyValueDAO.get_value( - resource=KeyValueResource.PKCE_CODE_VERIFIER, - key=tab_uuid, + self._tab_uuid = UUID(self._state["tab_id"]) + except (KeyError, ValueError): + # Legacy paths may use non-UUID tab ids; we still want to support + # them when ``database_id`` is set. The pre-create path below + # requires a valid UUID. + self._tab_uuid = None + + if database_id := self._state.get("database_id"): + self._database = DatabaseUserOAuth2TokensDAO.get_database(database_id) + if self._database is None: + raise DatabaseNotFoundError("Database not found") + self._oauth2_config = self._database.get_oauth2_config() + self._engine_spec = self._database.db_engine_spec + else: + if self._tab_uuid is None: + raise OAuth2Error( + "Pre-create OAuth2 callback requires a UUID tab_id", + ) + cached = KeyValueDAO.get_value( + resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN, + key=self._tab_uuid, codec=JsonKeyValueCodec(), ) - if kv_value: - code_verifier = kv_value.get("code_verifier") - KeyValueDAO.delete_entry(KeyValueResource.PKCE_CODE_VERIFIER, tab_uuid) + if not cached or not cached.get("config"): + raise OAuth2Error("Pre-create OAuth2 context not found or expired") + self._oauth2_config = cast(OAuth2ClientConfig, cached["config"]) + engine = self._state.get("engine") or cached.get("engine") + if not engine: + raise OAuth2Error("Pre-create OAuth2 context missing engine name") + self._engine_spec = get_engine_spec(engine) - token_response = self._database.db_engine_spec.get_oauth2_token( - oauth2_config, - self._parameters["code"], - code_verifier=code_verifier, + if self._oauth2_config is None: + raise OAuth2Error("No configuration found for OAuth2") + + def _pop_code_verifier(self) -> str | None: + if self._tab_uuid is None: + return None + kv_value = KeyValueDAO.get_value( + resource=KeyValueResource.PKCE_CODE_VERIFIER, + key=self._tab_uuid, + codec=JsonKeyValueCodec(), ) + if not kv_value: + return None + KeyValueDAO.delete_entry(KeyValueResource.PKCE_CODE_VERIFIER, self._tab_uuid) + return kv_value.get("code_verifier") + + def _cache_pre_create_token(self, token_response: OAuth2TokenResponse) -> None: + self._state = cast(OAuth2State, self._state) + self._tab_uuid = cast(UUID, self._tab_uuid) + self._oauth2_config = cast(OAuth2ClientConfig, self._oauth2_config) + self._engine_spec = cast(type[BaseEngineSpec], self._engine_spec) + + expires_on = datetime.now() + PRE_CREATE_TOKEN_TTL + KeyValueDAO.upsert_entry( + resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN, + key=self._tab_uuid, + value={ + "engine": self._engine_spec.engine, + "config": self._oauth2_config, + "user_id": self._state["user_id"], + "access_token": token_response["access_token"], + }, + codec=JsonKeyValueCodec(), + expires_on=expires_on, + ) + + def _persist_token( + self, + token_response: OAuth2TokenResponse, + ) -> DatabaseUserOAuth2Tokens: + self._state = cast(OAuth2State, self._state) - # delete old tokens if existing := DatabaseUserOAuth2TokensDAO.find_one_or_none( user_id=self._state["user_id"], database_id=self._state["database_id"], ): DatabaseUserOAuth2TokensDAO.delete([existing]) - # store tokens expiration = datetime.now() + timedelta(seconds=token_response["expires_in"]) return DatabaseUserOAuth2TokensDAO.create( attributes={ @@ -95,16 +186,3 @@ class OAuth2StoreTokenCommand(BaseCommand): "refresh_token": token_response.get("refresh_token"), }, ) - - def validate(self) -> None: - if error := self._parameters.get("error"): - raise OAuth2Error(error) - - self._state = decode_oauth2_state(self._parameters["state"]) - - if database := DatabaseUserOAuth2TokensDAO.get_database( - self._state["database_id"] - ): - self._database = database - else: - raise DatabaseNotFoundError("Database not found") diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py index 1395994d73e..fafeacc8b58 100644 --- a/superset/commands/database/test_connection.py +++ b/superset/commands/database/test_connection.py @@ -17,6 +17,7 @@ import logging from typing import Any, Optional +from flask import g from flask_babel import gettext as _ from sqlalchemy.exc import DBAPIError, NoSuchModuleError @@ -94,6 +95,13 @@ class TestConnectionDatabaseCommand(BaseCommand): self.validate() ex_str = "" + # Surface the wizard's tab_id (sent by the frontend) so that + # ``get_oauth2_access_token`` can find the pre-create OAuth2 token + # cached in the KV store, and so that ``start_oauth2_dance`` reuses + # this id instead of generating a new one. + if oauth2_tab_id := self._properties.get("oauth2_tab_id"): + g.oauth2_tab_id = oauth2_tab_id + url = make_url_safe(self._uri) engine_name = url.get_backend_name() diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index b399a25ad7c..3acf82924d6 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -623,6 +623,16 @@ class DatabaseTestConnectionSchema(DatabaseParametersSchemaMixin, Schema): ) ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) + oauth2_tab_id = fields.String( + metadata={ + "description": ( + "UUID identifying the wizard tab for the pre-create OAuth2 flow." + " Optional; when supplied, the engine will look up the pre-create" + " OAuth2 token cached under this key." + ), + }, + allow_none=True, + ) class TableMetadataOptionsResponse(TypedDict): diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index de9c6c12302..984b4c3fe3b 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -664,7 +664,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # Prevent circular import. from superset.daos.key_value import KeyValueDAO - tab_id = str(uuid4()) + # Reuse the wizard-supplied tab_id when present so the retry of + # "Test Connection" can find the same KV-cached token. Otherwise we + # generate a fresh one as before. + tab_id = getattr(g, "oauth2_tab_id", None) or str(uuid4()) default_redirect_uri = get_oauth2_redirect_uri() # Generate PKCE code verifier (RFC 7636) @@ -690,6 +693,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # belongs to. state: OAuth2State = { # Database ID and user ID are the primary key associated with the token. + # ``database_id`` is ``None`` during the "Create database" wizard — the + # callback caches the access token in the KV store instead of inserting + # a row in ``database_user_oauth2_tokens``. "database_id": database.id, "user_id": g.user.id, # In multi-instance deployments there might be a single proxy handling @@ -710,6 +716,25 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods if oauth2_config is None: raise OAuth2Error("No configuration found for OAuth2") + # Pre-create flow: the database has no row yet, so we can't look it up by id + # in the callback. Stash the engine spec + oauth2 config in the KV store + # alongside the existing PKCE entry, keyed by the same ``tab_id``. The + # callback reads this to exchange the code without a persisted ``Database``. + if database.id is None: + state["engine"] = cls.engine + KeyValueDAO.delete_expired_entries(KeyValueResource.OAUTH2_PRE_CREATE_TOKEN) + KeyValueDAO.create_entry( + resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN, + value={ + "engine": cls.engine, + "config": oauth2_config, + }, + codec=JsonKeyValueCodec(), + key=UUID(tab_id), + expires_on=datetime.now() + timedelta(minutes=5), + ) + db.session.commit() + oauth_url = cls.get_oauth2_authorization_uri( oauth2_config, state, diff --git a/superset/key_value/types.py b/superset/key_value/types.py index 2cc025e4265..352c6669618 100644 --- a/superset/key_value/types.py +++ b/superset/key_value/types.py @@ -46,6 +46,7 @@ class KeyValueResource(StrEnum): METASTORE_CACHE = "superset_metastore_cache" LOCK = "lock" PKCE_CODE_VERIFIER = "pkce_code_verifier" + OAUTH2_PRE_CREATE_TOKEN = "oauth2_pre_create_token" # noqa: S105 SQLLAB_PERMALINK = "sqllab_permalink" diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 39ff9ab1347..b2f2f6000dd 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -402,7 +402,15 @@ class OAuth2State(TypedDict, total=False): Type for the state passed during OAuth2. """ - database_id: int + # ``database_id`` is ``None`` during the "Create database" wizard, where the + # OAuth2 dance runs before the database has been persisted. In that case the + # access token is cached in the KV store keyed by ``tab_id`` until the user + # saves the database. + database_id: int | None user_id: int default_redirect_uri: str tab_id: str + # Engine backend code (e.g. ``"semanticapi"``), present only for pre-create + # dances so the callback can resolve the engine spec without a persisted + # database row. + engine: str diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 020f5397b3c..d77102235ce 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -24,10 +24,11 @@ import secrets from contextlib import contextmanager from datetime import datetime, timedelta, timezone from typing import Any, Iterator, TYPE_CHECKING +from uuid import UUID import backoff import jwt -from flask import current_app as app, url_for +from flask import current_app as app, g, url_for from marshmallow import EXCLUDE, fields, post_load, Schema, validate from werkzeug.routing import BuildError @@ -87,7 +88,7 @@ def generate_code_challenge(code_verifier: str) -> str: ) def get_oauth2_access_token( config: OAuth2ClientConfig, - database_id: int, + database_id: int | None, user_id: int, db_engine_spec: type[BaseEngineSpec], ) -> str | None: @@ -100,10 +101,18 @@ def get_oauth2_access_token( simultaneous requests for refreshing a stale token; in that case only the first process to acquire the lock will perform the refresh, and othe process should find a a valid token when they retry. + + When ``database_id`` is ``None`` (the "Create database" wizard, where the + database hasn't been persisted yet), look up the token in the KV store under + ``g.oauth2_tab_id``. The token was cached there by + :class:`OAuth2StoreTokenCommand` after the wizard's OAuth2 dance. """ # noqa: E501 # pylint: disable=import-outside-toplevel from superset.models.core import DatabaseUserOAuth2Tokens + if not database_id: + return _get_pre_create_access_token(user_id) + token = ( db.session.query(DatabaseUserOAuth2Tokens) .filter_by(user_id=user_id, database_id=database_id) @@ -124,6 +133,39 @@ def get_oauth2_access_token( return None +def _get_pre_create_access_token(user_id: int) -> str | None: + """ + Look up a pre-create OAuth2 token in the KV store. + + The KV entry is keyed by the wizard's ``tab_id``; the frontend passes it on + every "Test Connection" request, and the test-connection command exposes it + via :data:`flask.g.oauth2_tab_id`. + """ + # pylint: disable=import-outside-toplevel + from superset.daos.key_value import KeyValueDAO + from superset.key_value.types import JsonKeyValueCodec, KeyValueResource + + tab_id = getattr(g, "oauth2_tab_id", None) + if not tab_id: + return None + try: + tab_uuid = UUID(tab_id) + except (TypeError, ValueError): + return None + + cached = KeyValueDAO.get_value( + resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN, + key=tab_uuid, + codec=JsonKeyValueCodec(), + ) + if not cached: + return None + if cached.get("user_id") != user_id: + # the tab_id is the user's own, but be defensive + return None + return cached.get("access_token") + + def refresh_oauth2_token( config: OAuth2ClientConfig, database_id: int, @@ -203,14 +245,20 @@ def refresh_oauth2_token( def encode_oauth2_state(state: OAuth2State) -> str: """ Encode the OAuth2 state. + + ``database_id`` is ``None`` for the pre-create-database OAuth2 dance, + which means the token won't be persisted to ``database_user_oauth2_tokens`` + when the callback runs — see :class:`OAuth2StoreTokenCommand`. """ payload: dict[str, Any] = { "exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION, - "database_id": state["database_id"], + "database_id": state.get("database_id"), "user_id": state["user_id"], "default_redirect_uri": state["default_redirect_uri"], "tab_id": state["tab_id"], } + if engine := state.get("engine"): + payload["engine"] = engine encoded_state = jwt.encode( payload=payload, @@ -225,10 +273,11 @@ def encode_oauth2_state(state: OAuth2State) -> str: class OAuth2StateSchema(Schema): - database_id = fields.Int(required=True) + database_id = fields.Int(required=True, allow_none=True) user_id = fields.Int(required=True) default_redirect_uri = fields.Str(required=True) tab_id = fields.Str(required=True) + engine = fields.Str(required=False, load_default=None) # pylint: disable=unused-argument @post_load @@ -237,12 +286,15 @@ class OAuth2StateSchema(Schema): data: dict[str, Any], **kwargs: Any, ) -> OAuth2State: - return { + state: OAuth2State = { "database_id": data["database_id"], "user_id": data["user_id"], "default_redirect_uri": data["default_redirect_uri"], "tab_id": data["tab_id"], } + if data.get("engine"): + state["engine"] = data["engine"] + return state class Meta: # pylint: disable=too-few-public-methods # ignore `exp` diff --git a/tests/unit_tests/commands/databases/oauth2_test.py b/tests/unit_tests/commands/databases/oauth2_test.py index 0fbe2035d29..00b217cae09 100644 --- a/tests/unit_tests/commands/databases/oauth2_test.py +++ b/tests/unit_tests/commands/databases/oauth2_test.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any +from typing import Any, cast from unittest.mock import MagicMock import pytest @@ -27,6 +27,7 @@ from superset.daos.database import DatabaseUserOAuth2TokensDAO from superset.databases.schemas import OAuth2ProviderResponseSchema from superset.exceptions import OAuth2Error from superset.models.core import Database +from superset.superset_typing import OAuth2State from superset.utils.oauth2 import decode_oauth2_state, encode_oauth2_state @@ -135,6 +136,80 @@ def test_run_success( mock_create.assert_called_once() +def test_run_pre_create_caches_token_in_kv(mocker: MockerFixture) -> None: + """ + With ``state.database_id is None`` the command reads the engine + config + from the pre-create KV entry, exchanges the code, and writes the token + back to KV — *not* to ``database_user_oauth2_tokens``. + """ + state: OAuth2State = { + "user_id": 1, + "database_id": None, + "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", + "tab_id": "3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a", + "engine": "semanticapi", + } + parameters: dict[str, Any] = { + "code": "the-code", + "state": encode_oauth2_state(state), + } + + kv_dao = mocker.patch("superset.commands.database.oauth2.KeyValueDAO") + kv_dao.get_value.side_effect = [ + {"engine": "semanticapi", "config": {"id": "x", "secret": "y"}}, # validate() + None, # code_verifier lookup + ] + + engine_spec = mocker.MagicMock() + engine_spec.engine = "semanticapi" + engine_spec.get_oauth2_token.return_value = { + "access_token": "fresh-token", + "expires_in": 3600, + "refresh_token": "refresh", + } + mocker.patch( + "superset.commands.database.oauth2.get_engine_spec", + return_value=engine_spec, + ) + + dao_get_db = mocker.patch.object(DatabaseUserOAuth2TokensDAO, "get_database") + dao_create = mocker.patch.object(DatabaseUserOAuth2TokensDAO, "create") + + result = OAuth2StoreTokenCommand( + cast(OAuth2ProviderResponseSchema, parameters), + ).run() + + assert result is None + dao_get_db.assert_not_called() + dao_create.assert_not_called() + kv_dao.upsert_entry.assert_called_once() + upsert_kwargs = kv_dao.upsert_entry.call_args.kwargs + assert upsert_kwargs["value"]["access_token"] == "fresh-token" # noqa: S105 + assert upsert_kwargs["value"]["user_id"] == 1 + + +def test_run_pre_create_missing_kv_entry(mocker: MockerFixture) -> None: + """ + Pre-create flow with no cached entry should fail with a clear OAuth2Error. + """ + state: OAuth2State = { + "user_id": 1, + "database_id": None, + "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", + "tab_id": "3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a", + } + parameters: dict[str, Any] = {"code": "x", "state": encode_oauth2_state(state)} + + kv_dao = mocker.patch("superset.commands.database.oauth2.KeyValueDAO") + kv_dao.get_value.return_value = None + + with pytest.raises(OAuth2Error) as exc_info: + OAuth2StoreTokenCommand( + cast(OAuth2ProviderResponseSchema, parameters), + ).validate() + assert "Pre-create OAuth2 context" in exc_info.value.error.extra["error"] + + def test_run_existing_token( mocker: MockerFixture, mock_database: MagicMock, diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py index ac788ce66e5..a8bddf6c3cb 100644 --- a/tests/unit_tests/utils/oauth2_tests.py +++ b/tests/unit_tests/utils/oauth2_tests.py @@ -458,6 +458,83 @@ def test_encode_decode_oauth2_state( assert decoded["user_id"] == 2 +def test_encode_decode_oauth2_state_pre_create(mocker: MockerFixture) -> None: + """ + The pre-create dance encodes ``database_id=None`` and an ``engine`` + field; both must survive the round-trip so the callback can resolve the + engine spec without a persisted database. + """ + from superset.superset_typing import OAuth2State + + mocker.patch( + "flask.current_app.config", + {"SECRET_KEY": "test-secret-key", "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256"}, + ) + + state: OAuth2State = { + "database_id": None, + "user_id": 2, + "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", + "tab_id": "abc", + "engine": "semanticapi", + } + with freeze_time("2024-01-01"): + decoded = decode_oauth2_state(encode_oauth2_state(state)) + + assert decoded["database_id"] is None + assert decoded["engine"] == "semanticapi" + + +def test_get_oauth2_access_token_pre_create_hit(mocker: MockerFixture) -> None: + """ + With ``database_id=None`` and a matching KV entry, the cached pre-create + token is returned. + """ + mocker.patch( + "superset.utils.oauth2.g", + oauth2_tab_id="3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a", + ) + kv_dao = mocker.patch("superset.daos.key_value.KeyValueDAO") + kv_dao.get_value.return_value = { + "access_token": "cached-token", + "user_id": 7, + } + + assert ( + get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock()) + == "cached-token" + ) + + +def test_get_oauth2_access_token_pre_create_miss(mocker: MockerFixture) -> None: + """ + Without a tab id (or with no KV entry / wrong user) the lookup returns None. + """ + mocker.patch("superset.utils.oauth2.g", oauth2_tab_id=None) + assert ( + get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock()) + is None + ) + + mocker.patch( + "superset.utils.oauth2.g", + oauth2_tab_id="3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a", + ) + kv_dao = mocker.patch("superset.daos.key_value.KeyValueDAO") + kv_dao.get_value.return_value = None + assert ( + get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock()) + is None + ) + + kv_dao.get_value.return_value = {"access_token": "x", "user_id": 99} + # token belongs to another user + assert ( + get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock()) + is None + ) + + def test_get_oauth2_access_token_lock_not_acquired_no_error_log( mocker: MockerFixture, caplog: pytest.LogCaptureFixture,
