This is an automated email from the ASF dual-hosted git repository.

betodealmeida pushed a commit to branch oauth-during-db-creation
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 576a180ba3f0a9a3ed9e01304d63f9a6de0725e4
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed May 20 17:37:46 2026 -0400

    feat: OAuth2 during DB creation
---
 .../src/features/databases/DatabaseModal/index.tsx |  56 ++++++++
 superset/commands/database/oauth2.py               | 158 +++++++++++++++------
 superset/commands/database/test_connection.py      |   8 ++
 superset/databases/schemas.py                      |  10 ++
 superset/db_engine_specs/base.py                   |  27 +++-
 superset/key_value/types.py                        |   1 +
 superset/superset_typing.py                        |  10 +-
 superset/utils/oauth2.py                           |  62 +++++++-
 tests/unit_tests/commands/databases/oauth2_test.py |  77 +++++++++-
 tests/unit_tests/utils/oauth2_tests.py             |  77 ++++++++++
 10 files changed, 438 insertions(+), 48 deletions(-)

diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.tsx
index 1fa38c5e152..0600e16d57a 100644
--- a/superset-frontend/src/features/databases/DatabaseModal/index.tsx
+++ b/superset-frontend/src/features/databases/DatabaseModal/index.tsx
@@ -622,6 +622,15 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
   const [editNewDb, setEditNewDb] = useState<boolean>(false);
   const [isLoading, setLoading] = useState<boolean>(false);
   const [testInProgress, setTestInProgress] = useState<boolean>(false);
+  // Stable id sent on every "Test Connection" call so that the backend can
+  // correlate the wizard's OAuth2 dance across requests. Generated once per
+  // modal session — keeping the same value lets the second test_connection
+  // call (after the user finishes the OAuth2 dance) find the cached token.
+  const oauth2TabIdRef = useRef<string>(
+    typeof crypto !== 'undefined' && crypto.randomUUID
+      ? crypto.randomUUID()
+      : Math.random().toString(36).slice(2),
+  );
   const [passwords, setPasswords] = useState<Record<string, string>>({});
   const [sshTunnelPasswords, setSSHTunnelPasswords] = useState<
     Record<string, string>
@@ -727,6 +736,11 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
   }, [setValidationErrors, setHasValidated, clearError]);
 
   // Test Connection logic
+  // keep the latest ``testConnection`` callable accessible to long-lived
+  // listeners (BroadcastChannel/storage) without making them depend on every
+  // render of the modal.
+  const testConnectionRef = useRef<() => void>(() => {});
+
   const testConnection = () => {
     handleClearValidationErrors();
     if (!db?.sqlalchemy_uri) {
@@ -748,6 +762,7 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
               server_port: Number(db.ssh_tunnel!.server_port),
             }
           : undefined,
+      oauth2_tab_id: oauth2TabIdRef.current,
     };
     setTestInProgress(true);
     testDatabaseConnection(
@@ -764,6 +779,47 @@ const DatabaseModal: FunctionComponent<DatabaseModalProps> = ({
       },
     );
   };
+  testConnectionRef.current = testConnection;
+
+  // Re-run "Test Connection" automatically when the OAuth2 dance completes
+  // in the popup tab. The dance posts a message on the ``oauth`` broadcast
+  // channel (and a localStorage event for cross-context delivery) carrying
+  // the wizard's own tab_id; we react only to our own.
+  useEffect(() => {
+    const tabId = oauth2TabIdRef.current;
+
+    const handleComplete = (incomingTabId?: string) => {
+      if (incomingTabId === tabId) {
+        testConnectionRef.current();
+      }
+    };
+
+    const channel =
+      typeof BroadcastChannel !== 'undefined'
+        ? new BroadcastChannel('oauth')
+        : null;
+    if (channel) {
+      channel.onmessage = event => handleComplete(event.data?.tabId);
+    }
+
+    const handleStorage = (event: StorageEvent) => {
+      if (event.key !== 'oauth2_auth_complete' || !event.newValue) {
+        return;
+      }
+      try {
+        const payload = JSON.parse(event.newValue) as { tabId?: string };
+        handleComplete(payload.tabId);
+      } catch {
+        /* ignore */
+      }
+    };
+    window.addEventListener('storage', handleStorage);
+
+    return () => {
+      window.removeEventListener('storage', handleStorage);
+      channel?.close();
+    };
+  }, []);
 
   const getPlaceholder = (field: string) => {
     if (field === 'database') {
diff --git a/superset/commands/database/oauth2.py b/superset/commands/database/oauth2.py
index 8355bc0098e..2df792bf429 100644
--- a/superset/commands/database/oauth2.py
+++ b/superset/commands/database/oauth2.py
@@ -25,66 +25,157 @@ from superset.commands.database.exceptions import DatabaseNotFoundError
 from superset.daos.database import DatabaseUserOAuth2TokensDAO
 from superset.daos.key_value import KeyValueDAO
 from superset.databases.schemas import OAuth2ProviderResponseSchema
+from superset.db_engine_specs import get_engine_spec
+from superset.db_engine_specs.base import BaseEngineSpec
 from superset.exceptions import OAuth2Error
 from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
 from superset.models.core import Database, DatabaseUserOAuth2Tokens
-from superset.superset_typing import OAuth2State
+from superset.superset_typing import (
+    OAuth2ClientConfig,
+    OAuth2State,
+    OAuth2TokenResponse,
+)
 from superset.utils.decorators import on_error, transaction
 from superset.utils.oauth2 import decode_oauth2_state
 
+# how long the pre-create token cache lives in the KV store
+PRE_CREATE_TOKEN_TTL = timedelta(minutes=5)
+
 
 class OAuth2StoreTokenCommand(BaseCommand):
     """
-    Command to store OAuth2 tokens in the database.
+    Command to store OAuth2 tokens.
+
+    Normal flow: the OAuth2 callback resolves the database via ``state.database_id``
+    and persists access/refresh tokens to ``database_user_oauth2_tokens``.
+
+    Pre-create flow: when ``state.database_id`` is ``None`` (the database hasn't
+    been saved yet — typically the "Create database" wizard), the command reads
+    the OAuth2 client config and engine name from the KV store entry that
+    :meth:`BaseEngineSpec.start_oauth2_dance` stashed there, exchanges the code,
+    and caches the resulting access token in the same KV entry for
+    :func:`get_oauth2_access_token` to pick up on the retry.
     """
 
     def __init__(self, parameters: OAuth2ProviderResponseSchema):
         self._parameters = parameters
         self._state: OAuth2State | None = None
         self._database: Database | None = None
+        self._oauth2_config: OAuth2ClientConfig | None = None
+        self._engine_spec: type[BaseEngineSpec] | None = None
+        self._tab_uuid: UUID | None = None
 
     @transaction(on_error=partial(on_error, reraise=OAuth2Error))
-    def run(self) -> DatabaseUserOAuth2Tokens:
+    def run(self) -> DatabaseUserOAuth2Tokens | None:
         self.validate()
-        self._database = cast(Database, self._database)
         self._state = cast(OAuth2State, self._state)
-
-        oauth2_config = self._database.get_oauth2_config()
-        if oauth2_config is None:
-            raise OAuth2Error("No configuration found for OAuth2")
+        self._oauth2_config = cast(OAuth2ClientConfig, self._oauth2_config)
+        self._engine_spec = cast(type[BaseEngineSpec], self._engine_spec)
 
         # Look up PKCE code_verifier from KV store (RFC 7636)
-        code_verifier = None
-        tab_id = self._state["tab_id"]
+        code_verifier = self._pop_code_verifier()
+
+        token_response = self._engine_spec.get_oauth2_token(
+            self._oauth2_config,
+            self._parameters["code"],
+            code_verifier=code_verifier,
+        )
+
+        if self._database is None:
+            # Pre-create flow: cache the access token in the KV entry the
+            # initial dance created. The retry of "Test Connection" will read
+            # it via ``get_oauth2_access_token``.
+            self._cache_pre_create_token(token_response)
+            return None
+
+        return self._persist_token(token_response)
+
+    def validate(self) -> None:
+        if error := self._parameters.get("error"):
+            raise OAuth2Error(error)
+
+        self._state = decode_oauth2_state(self._parameters["state"])
+
         try:
-            tab_uuid = UUID(tab_id)
-        except ValueError:
-            tab_uuid = None
-
-        if tab_uuid:
-            kv_value = KeyValueDAO.get_value(
-                resource=KeyValueResource.PKCE_CODE_VERIFIER,
-                key=tab_uuid,
+            self._tab_uuid = UUID(self._state["tab_id"])
+        except (KeyError, ValueError):
+            # Legacy paths may use non-UUID tab ids; we still want to support
+            # them when ``database_id`` is set. The pre-create path below
+            # requires a valid UUID.
+            self._tab_uuid = None
+
+        if database_id := self._state.get("database_id"):
+            self._database = DatabaseUserOAuth2TokensDAO.get_database(database_id)
+            if self._database is None:
+                raise DatabaseNotFoundError("Database not found")
+            self._oauth2_config = self._database.get_oauth2_config()
+            self._engine_spec = self._database.db_engine_spec
+        else:
+            if self._tab_uuid is None:
+                raise OAuth2Error(
+                    "Pre-create OAuth2 callback requires a UUID tab_id",
+                )
+            cached = KeyValueDAO.get_value(
+                resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN,
+                key=self._tab_uuid,
                 codec=JsonKeyValueCodec(),
             )
-            if kv_value:
-                code_verifier = kv_value.get("code_verifier")
-                KeyValueDAO.delete_entry(KeyValueResource.PKCE_CODE_VERIFIER, tab_uuid)
+            if not cached or not cached.get("config"):
+                raise OAuth2Error("Pre-create OAuth2 context not found or expired")
+            self._oauth2_config = cast(OAuth2ClientConfig, cached["config"])
+            engine = self._state.get("engine") or cached.get("engine")
+            if not engine:
+                raise OAuth2Error("Pre-create OAuth2 context missing engine name")
+            self._engine_spec = get_engine_spec(engine)
 
-        token_response = self._database.db_engine_spec.get_oauth2_token(
-            oauth2_config,
-            self._parameters["code"],
-            code_verifier=code_verifier,
+        if self._oauth2_config is None:
+            raise OAuth2Error("No configuration found for OAuth2")
+
+    def _pop_code_verifier(self) -> str | None:
+        if self._tab_uuid is None:
+            return None
+        kv_value = KeyValueDAO.get_value(
+            resource=KeyValueResource.PKCE_CODE_VERIFIER,
+            key=self._tab_uuid,
+            codec=JsonKeyValueCodec(),
         )
+        if not kv_value:
+            return None
+        KeyValueDAO.delete_entry(KeyValueResource.PKCE_CODE_VERIFIER, self._tab_uuid)
+        return kv_value.get("code_verifier")
+
+    def _cache_pre_create_token(self, token_response: OAuth2TokenResponse) -> None:
+        self._state = cast(OAuth2State, self._state)
+        self._tab_uuid = cast(UUID, self._tab_uuid)
+        self._oauth2_config = cast(OAuth2ClientConfig, self._oauth2_config)
+        self._engine_spec = cast(type[BaseEngineSpec], self._engine_spec)
+
+        expires_on = datetime.now() + PRE_CREATE_TOKEN_TTL
+        KeyValueDAO.upsert_entry(
+            resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN,
+            key=self._tab_uuid,
+            value={
+                "engine": self._engine_spec.engine,
+                "config": self._oauth2_config,
+                "user_id": self._state["user_id"],
+                "access_token": token_response["access_token"],
+            },
+            codec=JsonKeyValueCodec(),
+            expires_on=expires_on,
+        )
+
+    def _persist_token(
+        self,
+        token_response: OAuth2TokenResponse,
+    ) -> DatabaseUserOAuth2Tokens:
+        self._state = cast(OAuth2State, self._state)
 
-        # delete old tokens
         if existing := DatabaseUserOAuth2TokensDAO.find_one_or_none(
             user_id=self._state["user_id"],
             database_id=self._state["database_id"],
         ):
             DatabaseUserOAuth2TokensDAO.delete([existing])
 
-        # store tokens
         expiration = datetime.now() + timedelta(seconds=token_response["expires_in"])
         return DatabaseUserOAuth2TokensDAO.create(
             attributes={
@@ -95,16 +186,3 @@ class OAuth2StoreTokenCommand(BaseCommand):
                 "refresh_token": token_response.get("refresh_token"),
             },
         )
-
-    def validate(self) -> None:
-        if error := self._parameters.get("error"):
-            raise OAuth2Error(error)
-
-        self._state = decode_oauth2_state(self._parameters["state"])
-
-        if database := DatabaseUserOAuth2TokensDAO.get_database(
-            self._state["database_id"]
-        ):
-            self._database = database
-        else:
-            raise DatabaseNotFoundError("Database not found")
diff --git a/superset/commands/database/test_connection.py b/superset/commands/database/test_connection.py
index 1395994d73e..fafeacc8b58 100644
--- a/superset/commands/database/test_connection.py
+++ b/superset/commands/database/test_connection.py
@@ -17,6 +17,7 @@
 import logging
 from typing import Any, Optional
 
+from flask import g
 from flask_babel import gettext as _
 from sqlalchemy.exc import DBAPIError, NoSuchModuleError
 
@@ -94,6 +95,13 @@ class TestConnectionDatabaseCommand(BaseCommand):
         self.validate()
         ex_str = ""
 
+        # Surface the wizard's tab_id (sent by the frontend) so that
+        # ``get_oauth2_access_token`` can find the pre-create OAuth2 token
+        # cached in the KV store, and so that ``start_oauth2_dance`` reuses
+        # this id instead of generating a new one.
+        if oauth2_tab_id := self._properties.get("oauth2_tab_id"):
+            g.oauth2_tab_id = oauth2_tab_id
+
         url = make_url_safe(self._uri)
         engine_name = url.get_backend_name()
 
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index b399a25ad7c..3acf82924d6 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -623,6 +623,16 @@ class DatabaseTestConnectionSchema(DatabaseParametersSchemaMixin, Schema):
     )
 
     ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
+    oauth2_tab_id = fields.String(
+        metadata={
+            "description": (
+                "UUID identifying the wizard tab for the pre-create OAuth2 flow."
+                " Optional; when supplied, the engine will look up the pre-create"
+                " OAuth2 token cached under this key."
+            ),
+        },
+        allow_none=True,
+    )
 
 
 class TableMetadataOptionsResponse(TypedDict):
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index de9c6c12302..984b4c3fe3b 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -664,7 +664,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         # Prevent circular import.
         from superset.daos.key_value import KeyValueDAO
 
-        tab_id = str(uuid4())
+        # Reuse the wizard-supplied tab_id when present so the retry of
+        # "Test Connection" can find the same KV-cached token. Otherwise we
+        # generate a fresh one as before.
+        tab_id = getattr(g, "oauth2_tab_id", None) or str(uuid4())
         default_redirect_uri = get_oauth2_redirect_uri()
 
         # Generate PKCE code verifier (RFC 7636)
@@ -690,6 +693,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         # belongs to.
         state: OAuth2State = {
             # Database ID and user ID are the primary key associated with the token.
+            # ``database_id`` is ``None`` during the "Create database" wizard — the
+            # callback caches the access token in the KV store instead of inserting
+            # a row in ``database_user_oauth2_tokens``.
             "database_id": database.id,
             "user_id": g.user.id,
             # In multi-instance deployments there might be a single proxy handling
@@ -710,6 +716,25 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         if oauth2_config is None:
             raise OAuth2Error("No configuration found for OAuth2")
 
+        # Pre-create flow: the database has no row yet, so we can't look it up by id
+        # in the callback. Stash the engine spec + oauth2 config in the KV store
+        # alongside the existing PKCE entry, keyed by the same ``tab_id``. The
+        # callback reads this to exchange the code without a persisted ``Database``.
+        if database.id is None:
+            state["engine"] = cls.engine
+            KeyValueDAO.delete_expired_entries(KeyValueResource.OAUTH2_PRE_CREATE_TOKEN)
+            KeyValueDAO.create_entry(
+                resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN,
+                value={
+                    "engine": cls.engine,
+                    "config": oauth2_config,
+                },
+                codec=JsonKeyValueCodec(),
+                key=UUID(tab_id),
+                expires_on=datetime.now() + timedelta(minutes=5),
+            )
+            db.session.commit()
+
         oauth_url = cls.get_oauth2_authorization_uri(
             oauth2_config,
             state,
diff --git a/superset/key_value/types.py b/superset/key_value/types.py
index 2cc025e4265..352c6669618 100644
--- a/superset/key_value/types.py
+++ b/superset/key_value/types.py
@@ -46,6 +46,7 @@ class KeyValueResource(StrEnum):
     METASTORE_CACHE = "superset_metastore_cache"
     LOCK = "lock"
     PKCE_CODE_VERIFIER = "pkce_code_verifier"
+    OAUTH2_PRE_CREATE_TOKEN = "oauth2_pre_create_token"  # noqa: S105
     SQLLAB_PERMALINK = "sqllab_permalink"
 
 
diff --git a/superset/superset_typing.py b/superset/superset_typing.py
index 39ff9ab1347..b2f2f6000dd 100644
--- a/superset/superset_typing.py
+++ b/superset/superset_typing.py
@@ -402,7 +402,15 @@ class OAuth2State(TypedDict, total=False):
     Type for the state passed during OAuth2.
     """
 
-    database_id: int
+    # ``database_id`` is ``None`` during the "Create database" wizard, where the
+    # OAuth2 dance runs before the database has been persisted. In that case the
+    # access token is cached in the KV store keyed by ``tab_id`` until the user
+    # saves the database.
+    database_id: int | None
     user_id: int
     default_redirect_uri: str
     tab_id: str
+    # Engine backend code (e.g. ``"semanticapi"``), present only for pre-create
+    # dances so the callback can resolve the engine spec without a persisted
+    # database row.
+    engine: str
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index 020f5397b3c..d77102235ce 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -24,10 +24,11 @@ import secrets
 from contextlib import contextmanager
 from datetime import datetime, timedelta, timezone
 from typing import Any, Iterator, TYPE_CHECKING
+from uuid import UUID
 
 import backoff
 import jwt
-from flask import current_app as app, url_for
+from flask import current_app as app, g, url_for
 from marshmallow import EXCLUDE, fields, post_load, Schema, validate
 from werkzeug.routing import BuildError
 
@@ -87,7 +88,7 @@ def generate_code_challenge(code_verifier: str) -> str:
 )
 def get_oauth2_access_token(
     config: OAuth2ClientConfig,
-    database_id: int,
+    database_id: int | None,
     user_id: int,
     db_engine_spec: type[BaseEngineSpec],
 ) -> str | None:
@@ -100,10 +101,18 @@ def get_oauth2_access_token(
     simultaneous requests for refreshing a stale token; in that case only the first
     process to acquire the lock will perform the refresh, and othe process should find a
     a valid token when they retry.
+
+    When ``database_id`` is ``None`` (the "Create database" wizard, where the
+    database hasn't been persisted yet), look up the token in the KV store under
+    ``g.oauth2_tab_id``. The token was cached there by
+    :class:`OAuth2StoreTokenCommand` after the wizard's OAuth2 dance.
     """  # noqa: E501
     # pylint: disable=import-outside-toplevel
     from superset.models.core import DatabaseUserOAuth2Tokens
 
+    if not database_id:
+        return _get_pre_create_access_token(user_id)
+
     token = (
         db.session.query(DatabaseUserOAuth2Tokens)
         .filter_by(user_id=user_id, database_id=database_id)
@@ -124,6 +133,39 @@ def get_oauth2_access_token(
     return None
 
 
+def _get_pre_create_access_token(user_id: int) -> str | None:
+    """
+    Look up a pre-create OAuth2 token in the KV store.
+
+    The KV entry is keyed by the wizard's ``tab_id``; the frontend passes it on
+    every "Test Connection" request, and the test-connection command exposes it
+    via :data:`flask.g.oauth2_tab_id`.
+    """
+    # pylint: disable=import-outside-toplevel
+    from superset.daos.key_value import KeyValueDAO
+    from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
+
+    tab_id = getattr(g, "oauth2_tab_id", None)
+    if not tab_id:
+        return None
+    try:
+        tab_uuid = UUID(tab_id)
+    except (TypeError, ValueError):
+        return None
+
+    cached = KeyValueDAO.get_value(
+        resource=KeyValueResource.OAUTH2_PRE_CREATE_TOKEN,
+        key=tab_uuid,
+        codec=JsonKeyValueCodec(),
+    )
+    if not cached:
+        return None
+    if cached.get("user_id") != user_id:
+        # the tab_id is the user's own, but be defensive
+        return None
+    return cached.get("access_token")
+
+
 def refresh_oauth2_token(
     config: OAuth2ClientConfig,
     database_id: int,
@@ -203,14 +245,20 @@ def refresh_oauth2_token(
 def encode_oauth2_state(state: OAuth2State) -> str:
     """
     Encode the OAuth2 state.
+
+    ``database_id`` is ``None`` for the pre-create-database OAuth2 dance,
+    which means the token won't be persisted to ``database_user_oauth2_tokens``
+    when the callback runs — see :class:`OAuth2StoreTokenCommand`.
     """
     payload: dict[str, Any] = {
         "exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION,
-        "database_id": state["database_id"],
+        "database_id": state.get("database_id"),
         "user_id": state["user_id"],
         "default_redirect_uri": state["default_redirect_uri"],
         "tab_id": state["tab_id"],
     }
+    if engine := state.get("engine"):
+        payload["engine"] = engine
 
     encoded_state = jwt.encode(
         payload=payload,
@@ -225,10 +273,11 @@ def encode_oauth2_state(state: OAuth2State) -> str:
 
 
 class OAuth2StateSchema(Schema):
-    database_id = fields.Int(required=True)
+    database_id = fields.Int(required=True, allow_none=True)
     user_id = fields.Int(required=True)
     default_redirect_uri = fields.Str(required=True)
     tab_id = fields.Str(required=True)
+    engine = fields.Str(required=False, load_default=None)
 
     # pylint: disable=unused-argument
     @post_load
@@ -237,12 +286,15 @@ class OAuth2StateSchema(Schema):
         data: dict[str, Any],
         **kwargs: Any,
     ) -> OAuth2State:
-        return {
+        state: OAuth2State = {
             "database_id": data["database_id"],
             "user_id": data["user_id"],
             "default_redirect_uri": data["default_redirect_uri"],
             "tab_id": data["tab_id"],
         }
+        if data.get("engine"):
+            state["engine"] = data["engine"]
+        return state
 
     class Meta:  # pylint: disable=too-few-public-methods
         # ignore `exp`
diff --git a/tests/unit_tests/commands/databases/oauth2_test.py b/tests/unit_tests/commands/databases/oauth2_test.py
index 0fbe2035d29..00b217cae09 100644
--- a/tests/unit_tests/commands/databases/oauth2_test.py
+++ b/tests/unit_tests/commands/databases/oauth2_test.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Any
+from typing import Any, cast
 from unittest.mock import MagicMock
 
 import pytest
@@ -27,6 +27,7 @@ from superset.daos.database import DatabaseUserOAuth2TokensDAO
 from superset.databases.schemas import OAuth2ProviderResponseSchema
 from superset.exceptions import OAuth2Error
 from superset.models.core import Database
+from superset.superset_typing import OAuth2State
 from superset.utils.oauth2 import decode_oauth2_state, encode_oauth2_state
 
 
@@ -135,6 +136,80 @@ def test_run_success(
     mock_create.assert_called_once()
 
 
+def test_run_pre_create_caches_token_in_kv(mocker: MockerFixture) -> None:
+    """
+    With ``state.database_id is None`` the command reads the engine + config
+    from the pre-create KV entry, exchanges the code, and writes the token
+    back to KV — *not* to ``database_user_oauth2_tokens``.
+    """
+    state: OAuth2State = {
+        "user_id": 1,
+        "database_id": None,
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+        "tab_id": "3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a",
+        "engine": "semanticapi",
+    }
+    parameters: dict[str, Any] = {
+        "code": "the-code",
+        "state": encode_oauth2_state(state),
+    }
+
+    kv_dao = mocker.patch("superset.commands.database.oauth2.KeyValueDAO")
+    kv_dao.get_value.side_effect = [
+        {"engine": "semanticapi", "config": {"id": "x", "secret": "y"}},  # validate()
+        None,  # code_verifier lookup
+    ]
+
+    engine_spec = mocker.MagicMock()
+    engine_spec.engine = "semanticapi"
+    engine_spec.get_oauth2_token.return_value = {
+        "access_token": "fresh-token",
+        "expires_in": 3600,
+        "refresh_token": "refresh",
+    }
+    mocker.patch(
+        "superset.commands.database.oauth2.get_engine_spec",
+        return_value=engine_spec,
+    )
+
+    dao_get_db = mocker.patch.object(DatabaseUserOAuth2TokensDAO, "get_database")
+    dao_create = mocker.patch.object(DatabaseUserOAuth2TokensDAO, "create")
+
+    result = OAuth2StoreTokenCommand(
+        cast(OAuth2ProviderResponseSchema, parameters),
+    ).run()
+
+    assert result is None
+    dao_get_db.assert_not_called()
+    dao_create.assert_not_called()
+    kv_dao.upsert_entry.assert_called_once()
+    upsert_kwargs = kv_dao.upsert_entry.call_args.kwargs
+    assert upsert_kwargs["value"]["access_token"] == "fresh-token"  # noqa: S105
+    assert upsert_kwargs["value"]["user_id"] == 1
+
+
+def test_run_pre_create_missing_kv_entry(mocker: MockerFixture) -> None:
+    """
+    Pre-create flow with no cached entry should fail with a clear OAuth2Error.
+    """
+    state: OAuth2State = {
+        "user_id": 1,
+        "database_id": None,
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+        "tab_id": "3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a",
+    }
+    parameters: dict[str, Any] = {"code": "x", "state": encode_oauth2_state(state)}
+
+    kv_dao = mocker.patch("superset.commands.database.oauth2.KeyValueDAO")
+    kv_dao.get_value.return_value = None
+
+    with pytest.raises(OAuth2Error) as exc_info:
+        OAuth2StoreTokenCommand(
+            cast(OAuth2ProviderResponseSchema, parameters),
+        ).validate()
+    assert "Pre-create OAuth2 context" in exc_info.value.error.extra["error"]
+
+
 def test_run_existing_token(
     mocker: MockerFixture,
     mock_database: MagicMock,
diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py
index ac788ce66e5..a8bddf6c3cb 100644
--- a/tests/unit_tests/utils/oauth2_tests.py
+++ b/tests/unit_tests/utils/oauth2_tests.py
@@ -458,6 +458,83 @@ def test_encode_decode_oauth2_state(
     assert decoded["user_id"] == 2
 
 
+def test_encode_decode_oauth2_state_pre_create(mocker: MockerFixture) -> None:
+    """
+    The pre-create dance encodes ``database_id=None`` and an ``engine``
+    field; both must survive the round-trip so the callback can resolve the
+    engine spec without a persisted database.
+    """
+    from superset.superset_typing import OAuth2State
+
+    mocker.patch(
+        "flask.current_app.config",
+        {"SECRET_KEY": "test-secret-key", "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256"},
+    )
+
+    state: OAuth2State = {
+        "database_id": None,
+        "user_id": 2,
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+        "tab_id": "abc",
+        "engine": "semanticapi",
+    }
+    with freeze_time("2024-01-01"):
+        decoded = decode_oauth2_state(encode_oauth2_state(state))
+
+    assert decoded["database_id"] is None
+    assert decoded["engine"] == "semanticapi"
+
+
+def test_get_oauth2_access_token_pre_create_hit(mocker: MockerFixture) -> None:
+    """
+    With ``database_id=None`` and a matching KV entry, the cached pre-create
+    token is returned.
+    """
+    mocker.patch(
+        "superset.utils.oauth2.g",
+        oauth2_tab_id="3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a",
+    )
+    kv_dao = mocker.patch("superset.daos.key_value.KeyValueDAO")
+    kv_dao.get_value.return_value = {
+        "access_token": "cached-token",
+        "user_id": 7,
+    }
+
+    assert (
+        get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock())
+        == "cached-token"
+    )
+
+
+def test_get_oauth2_access_token_pre_create_miss(mocker: MockerFixture) -> None:
+    """
+    Without a tab id (or with no KV entry / wrong user) the lookup returns None.
+    """
+    mocker.patch("superset.utils.oauth2.g", oauth2_tab_id=None)
+    assert (
+        get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock())
+        is None
+    )
+
+    mocker.patch(
+        "superset.utils.oauth2.g",
+        oauth2_tab_id="3a3a3a3a-3a3a-3a3a-3a3a-3a3a3a3a3a3a",
+    )
+    kv_dao = mocker.patch("superset.daos.key_value.KeyValueDAO")
+    kv_dao.get_value.return_value = None
+    assert (
+        get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock())
+        is None
+    )
+
+    kv_dao.get_value.return_value = {"access_token": "x", "user_id": 99}
+    # token belongs to another user
+    assert (
+        get_oauth2_access_token(DUMMY_OAUTH2_CONFIG, None, 7, mocker.MagicMock())
+        is None
+    )
+
+
 def test_get_oauth2_access_token_lock_not_acquired_no_error_log(
     mocker: MockerFixture,
     caplog: pytest.LogCaptureFixture,

Reply via email to