This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 8695239372 feat: `OAuth2StoreTokenCommand` (#32546)
8695239372 is described below

commit 86952393728324041b18afc1e3d7fbc69ea32c21
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu Mar 13 09:45:24 2025 -0400

    feat: `OAuth2StoreTokenCommand` (#32546)
---
 superset/commands/database/oauth2.py               |  88 +++++++++++
 superset/daos/database.py                          |  11 ++
 superset/databases/api.py                          |  50 +-----
 tests/unit_tests/commands/databases/oauth2_test.py | 168 +++++++++++++++++++++
 tests/unit_tests/databases/api_test.py             |  34 ++---
 5 files changed, 290 insertions(+), 61 deletions(-)

diff --git a/superset/commands/database/oauth2.py b/superset/commands/database/oauth2.py
new file mode 100644
index 0000000000..f7259077bc
--- /dev/null
+++ b/superset/commands/database/oauth2.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime, timedelta
+from functools import partial
+from typing import cast
+
+from superset.commands.base import BaseCommand
+from superset.commands.database.exceptions import DatabaseNotFoundError
+from superset.daos.database import DatabaseUserOAuth2TokensDAO
+from superset.databases.schemas import OAuth2ProviderResponseSchema
+from superset.exceptions import OAuth2Error
+from superset.models.core import Database, DatabaseUserOAuth2Tokens
+from superset.superset_typing import OAuth2State
+from superset.utils.decorators import on_error, transaction
+from superset.utils.oauth2 import decode_oauth2_state
+
+
+class OAuth2StoreTokenCommand(BaseCommand):
+    """
+    Command to store OAuth2 tokens in the database.
+    """
+
+    def __init__(self, parameters: OAuth2ProviderResponseSchema):
+        self._parameters = parameters
+        self._state: OAuth2State | None = None
+        self._database: Database | None = None
+
+    @transaction(on_error=partial(on_error, reraise=OAuth2Error))
+    def run(self) -> DatabaseUserOAuth2Tokens:
+        self.validate()
+        self._database = cast(Database, self._database)
+        self._state = cast(OAuth2State, self._state)
+
+        oauth2_config = self._database.get_oauth2_config()
+        if oauth2_config is None:
+            raise OAuth2Error("No configuration found for OAuth2")
+
+        token_response = self._database.db_engine_spec.get_oauth2_token(
+            oauth2_config,
+            self._parameters["code"],
+        )
+
+        # delete old tokens
+        if existing := DatabaseUserOAuth2TokensDAO.find_one_or_none(
+            user_id=self._state["user_id"],
+            database_id=self._state["database_id"],
+        ):
+            DatabaseUserOAuth2TokensDAO.delete([existing])
+
+        # store tokens
+        expiration = datetime.now() + timedelta(seconds=token_response["expires_in"])
+        return DatabaseUserOAuth2TokensDAO.create(
+            attributes={
+                "user_id": self._state["user_id"],
+                "database_id": self._state["database_id"],
+                "access_token": token_response["access_token"],
+                "access_token_expiration": expiration,
+                "refresh_token": token_response.get("refresh_token"),
+            },
+        )
+
+    def validate(self) -> None:
+        if error := self._parameters.get("error"):
+            raise OAuth2Error(error)
+
+        self._state = decode_oauth2_state(self._parameters["state"])
+
+        if database := DatabaseUserOAuth2TokensDAO.get_database(
+            self._state["database_id"]
+        ):
+            self._database = database
+        else:
+            raise DatabaseNotFoundError("Database not found")
diff --git a/superset/daos/database.py b/superset/daos/database.py
index 09b2fedf93..fa035534ee 100644
--- a/superset/daos/database.py
+++ b/superset/daos/database.py
@@ -195,3 +195,14 @@ class DatabaseUserOAuth2TokensDAO(BaseDAO[DatabaseUserOAuth2Tokens]):
     """
     DAO for OAuth2 tokens.
     """
+
+    @classmethod
+    def get_database(cls, database_id: int) -> Database | None:
+        """
+        Returns the database.
+
+        Note that this is different from `DatabaseDAO.find_by_id(database_id)` because
+        this DAO doesn't have any filters, so it can be called even for users without
+        database access (which is necessary for OAuth2).
+        """
+        return db.session.query(Database).filter_by(id=database_id).one_or_none()
diff --git a/superset/databases/api.py b/superset/databases/api.py
index b4196b403c..5c3e024e73 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -19,7 +19,7 @@
 from __future__ import annotations
 
 import logging
-from datetime import datetime, timedelta
+from datetime import datetime
 from io import BytesIO
 from typing import Any, cast
 from zipfile import is_zipfile, ZipFile
@@ -46,6 +46,7 @@ from superset.commands.database.exceptions import (
 )
 from superset.commands.database.export import ExportDatabasesCommand
 from superset.commands.database.importers.dispatcher import ImportDatabasesCommand
+from superset.commands.database.oauth2 import OAuth2StoreTokenCommand
 from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
 from superset.commands.database.ssh_tunnel.exceptions import (
     SSHTunnelDatabasePortError,
@@ -72,7 +73,7 @@ from superset.commands.importers.exceptions import (
 )
 from superset.commands.importers.v1.utils import get_contents_from_bundle
 from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
-from superset.daos.database import DatabaseDAO, DatabaseUserOAuth2TokensDAO
+from superset.daos.database import DatabaseDAO
 from superset.databases.decorators import check_table_access
 from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter
 from superset.databases.schemas import (
@@ -109,7 +110,6 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import (
     DatabaseNotFoundException,
     InvalidPayloadSchemaError,
-    OAuth2Error,
     OAuth2RedirectError,
     SupersetErrorsException,
     SupersetException,
@@ -1433,51 +1433,15 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
               $ref: '#/components/responses/500'
         """
         parameters = OAuth2ProviderResponseSchema().load(request.args)
+        command = OAuth2StoreTokenCommand(parameters)
+        command.run()
 
-        if "error" in parameters:
-            raise OAuth2Error(parameters["error"])
-
-        # note that when decoding the state we will perform JWT validation, preventing a
-        # malicious payload that would insert a bogus database token, or delete an
-        # existing one.
         state = decode_oauth2_state(parameters["state"])
+        tab_id = state["tab_id"]
 
-        # exchange code for access/refresh tokens
-        database = DatabaseDAO.find_by_id(state["database_id"], skip_base_filter=True)
-        if database is None:
-            return self.response_404()
-
-        oauth2_config = database.get_oauth2_config()
-        if oauth2_config is None:
-            raise OAuth2Error("No configuration found for OAuth2")
-
-        token_response = database.db_engine_spec.get_oauth2_token(
-            oauth2_config,
-            parameters["code"],
-        )
-
-        # delete old tokens
-        existing = DatabaseUserOAuth2TokensDAO.find_one_or_none(
-            user_id=state["user_id"],
-            database_id=state["database_id"],
-        )
-        if existing:
-            DatabaseUserOAuth2TokensDAO.delete([existing])
-
-        # store tokens
-        expiration = datetime.now() + timedelta(seconds=token_response["expires_in"])
-        DatabaseUserOAuth2TokensDAO.create(
-            attributes={
-                "user_id": state["user_id"],
-                "database_id": state["database_id"],
-                "access_token": token_response["access_token"],
-                "access_token_expiration": expiration,
-                "refresh_token": token_response.get("refresh_token"),
-            },
-        )
         # return blank page that closes itself
         return make_response(
-            render_template("superset/oauth2.html", tab_id=state["tab_id"]),
+            render_template("superset/oauth2.html", tab_id=tab_id),
             200,
         )
 
diff --git a/tests/unit_tests/commands/databases/oauth2_test.py b/tests/unit_tests/commands/databases/oauth2_test.py
new file mode 100644
index 0000000000..0fbe2035d2
--- /dev/null
+++ b/tests/unit_tests/commands/databases/oauth2_test.py
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+from pytest_mock import MockerFixture
+
+from superset.commands.database.exceptions import DatabaseNotFoundError
+from superset.commands.database.oauth2 import OAuth2StoreTokenCommand
+from superset.daos.database import DatabaseUserOAuth2TokensDAO
+from superset.databases.schemas import OAuth2ProviderResponseSchema
+from superset.exceptions import OAuth2Error
+from superset.models.core import Database
+from superset.utils.oauth2 import decode_oauth2_state, encode_oauth2_state
+
+
+@pytest.fixture
+def mock_database(mocker: MockerFixture) -> MagicMock:
+    database = mocker.MagicMock(spec=Database)
+    database.get_oauth2_config.return_value = {
+        "client_id": "test",
+        "client_secret": "secret",
+    }
+    database.db_engine_spec.get_oauth2_token.return_value = {
+        "access_token": "test_access_token",
+        "expires_in": 3600,
+        "refresh_token": "test_refresh_token",
+    }
+    return database
+
+
+@pytest.fixture
+def mock_state() -> str:
+    return encode_oauth2_state(
+        {
+            "user_id": 1,
+            "database_id": 123,
+            "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+            "tab_id": "1234",
+        }
+    )
+
+
+@pytest.fixture
+def mock_parameters(mock_state: str) -> dict[str, Any]:
+    return {"code": "test_code", "state": mock_state}
+
+
+def test_validate_success(
+    mocker: MockerFixture,
+    mock_database: MagicMock,
+    mock_state: str,
+    mock_parameters: OAuth2ProviderResponseSchema,
+) -> None:
+    mocker.patch("superset.utils.oauth2.decode_oauth2_state", return_value=mock_state)
+    mocker.patch.object(
+        DatabaseUserOAuth2TokensDAO,
+        "get_database",
+        return_value=mock_database,
+    )
+
+    command = OAuth2StoreTokenCommand(mock_parameters)
+    command.validate()
+
+    assert command._database == mock_database
+    assert command._state == decode_oauth2_state(mock_state)
+
+
+def test_validate_database_not_found(
+    mocker: MockerFixture,
+    mock_parameters: OAuth2ProviderResponseSchema,
+) -> None:
+    mocker.patch(
+        "superset.utils.oauth2.decode_oauth2_state",
+        return_value={"database_id": 999},
+    )
+    mocker.patch.object(DatabaseUserOAuth2TokensDAO, "get_database", return_value=None)
+
+    command = OAuth2StoreTokenCommand(mock_parameters)
+    with pytest.raises(DatabaseNotFoundError, match="Database not found"):
+        command.validate()
+
+
+def test_validate_oauth2_error(mock_parameters: OAuth2ProviderResponseSchema) -> None:
+    mock_parameters["error"] = "OAuth2 failure"
+    command = OAuth2StoreTokenCommand(mock_parameters)
+    with pytest.raises(OAuth2Error, match="Something went wrong while doing OAuth2"):
+        command.validate()
+
+
+def test_run_success(
+    mocker: MockerFixture,
+    mock_database: MagicMock,
+    mock_state: str,
+    mock_parameters: OAuth2ProviderResponseSchema,
+) -> None:
+    mocker.patch.object(
+        DatabaseUserOAuth2TokensDAO,
+        "get_database",
+        return_value=mock_database,
+    )
+    mocker.patch.object(
+        DatabaseUserOAuth2TokensDAO,
+        "find_one_or_none",
+        return_value=None,
+    )
+    mocker.patch.object(DatabaseUserOAuth2TokensDAO, "delete")
+    mock_create = mocker.patch.object(
+        DatabaseUserOAuth2TokensDAO,
+        "create",
+        return_value="new_token",
+    )
+    mocker.patch("superset.utils.oauth2.decode_oauth2_state", return_value=mock_state)
+
+    command = OAuth2StoreTokenCommand(mock_parameters)
+    result = command.run()
+
+    assert result == "new_token"
+    mock_create.assert_called_once()
+
+
+def test_run_existing_token(
+    mocker: MockerFixture,
+    mock_database: MagicMock,
+    mock_state: str,
+    mock_parameters: OAuth2ProviderResponseSchema,
+) -> None:
+    mocker.patch.object(
+        DatabaseUserOAuth2TokensDAO,
+        "get_database",
+        return_value=mock_database,
+    )
+    existing_token = MagicMock()
+    mocker.patch.object(
+        DatabaseUserOAuth2TokensDAO,
+        "find_one_or_none",
+        return_value=existing_token,
+    )
+    mock_delete = mocker.patch.object(DatabaseUserOAuth2TokensDAO, "delete")
+    mock_create = mocker.patch.object(
+        DatabaseUserOAuth2TokensDAO,
+        "create",
+        return_value="new_token",
+    )
+    mocker.patch("superset.utils.oauth2.decode_oauth2_state", return_value=mock_state)
+
+    command = OAuth2StoreTokenCommand(mock_parameters)
+    result = command.run()
+
+    assert result == "new_token"
+    mock_delete.assert_called_once_with([existing_token])
+    mock_create.assert_called_once()
diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py
index 64d99638de..f7a3da00b1 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -41,7 +41,9 @@ from superset.db_engine_specs.sqlite import SqliteEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import OAuth2RedirectError, SupersetSecurityException
 from superset.sql_parse import Table
+from superset.superset_typing import OAuth2State
 from superset.utils import json
+from superset.utils.oauth2 import encode_oauth2_state
 from tests.unit_tests.fixtures.common import (
     create_columnar_file,
     create_csv_file,
@@ -752,6 +754,7 @@ def test_oauth2_happy_path(
     Database.metadata.create_all(session.get_bind())  # pylint: disable=no-member
     db.session.add(
         Database(
+            id=1,
             database_name="my_db",
             sqlalchemy_uri="sqlite://",
             uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
@@ -771,13 +774,12 @@ def test_oauth2_happy_path(
         "refresh_token": "ZZZ",
     }
 
-    state = {
+    state: OAuth2State = {
         "user_id": 1,
         "database_id": 1,
-        "tab_id": 42,
+        "tab_id": "42",
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
     }
-    decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
-    decode_oauth2_state.return_value = state
 
     mocker.patch("superset.databases.api.render_template", return_value="OK")
 
@@ -785,13 +787,12 @@ def test_oauth2_happy_path(
         response = client.get(
             "/api/v1/database/oauth2/",
             query_string={
-                "state": "some%2Estate",
+                "state": encode_oauth2_state(state),
                 "code": "XXX",
             },
         )
 
     assert response.status_code == 200
-    decode_oauth2_state.assert_called_with("some%2Estate")
     get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
 
     token = db.session.query(DatabaseUserOAuth2Tokens).one()
@@ -841,13 +842,12 @@ def test_oauth2_permissions(
         "refresh_token": "ZZZ",
     }
 
-    state = {
+    state: OAuth2State = {
         "user_id": 1,
         "database_id": 1,
-        "tab_id": 42,
+        "tab_id": "42",
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
     }
-    decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
-    decode_oauth2_state.return_value = state
 
     mocker.patch("superset.databases.api.render_template", return_value="OK")
 
@@ -855,13 +855,12 @@ def test_oauth2_permissions(
         response = client.get(
             "/api/v1/database/oauth2/",
             query_string={
-                "state": "some%2Estate",
+                "state": encode_oauth2_state(state),
                 "code": "XXX",
             },
         )
 
     assert response.status_code == 200
-    decode_oauth2_state.assert_called_with("some%2Estate")
     get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
 
     token = db.session.query(DatabaseUserOAuth2Tokens).one()
@@ -916,13 +915,12 @@ def test_oauth2_multiple_tokens(
         },
     ]
 
-    state = {
+    state: OAuth2State = {
         "user_id": 1,
         "database_id": 1,
-        "tab_id": 42,
+        "tab_id": "42",
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
     }
-    decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
-    decode_oauth2_state.return_value = state
 
     mocker.patch("superset.databases.api.render_template", return_value="OK")
 
@@ -930,7 +928,7 @@ def test_oauth2_multiple_tokens(
         response = client.get(
             "/api/v1/database/oauth2/",
             query_string={
-                "state": "some%2Estate",
+                "state": encode_oauth2_state(state),
                 "code": "XXX",
             },
         )
@@ -939,7 +937,7 @@ def test_oauth2_multiple_tokens(
         response = client.get(
             "/api/v1/database/oauth2/",
             query_string={
-                "state": "some%2Estate",
+                "state": encode_oauth2_state(state),
                 "code": "XXX",
             },
         )

Reply via email to