This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 8695239372 feat: `OAuth2StoreTokenCommand` (#32546)
8695239372 is described below
commit 86952393728324041b18afc1e3d7fbc69ea32c21
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu Mar 13 09:45:24 2025 -0400
feat: `OAuth2StoreTokenCommand` (#32546)
---
superset/commands/database/oauth2.py | 88 +++++++++++
superset/daos/database.py | 11 ++
superset/databases/api.py | 50 +-----
tests/unit_tests/commands/databases/oauth2_test.py | 168 +++++++++++++++++++++
tests/unit_tests/databases/api_test.py | 34 ++---
5 files changed, 290 insertions(+), 61 deletions(-)
diff --git a/superset/commands/database/oauth2.py b/superset/commands/database/oauth2.py
new file mode 100644
index 0000000000..f7259077bc
--- /dev/null
+++ b/superset/commands/database/oauth2.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime, timedelta
+from functools import partial
+from typing import cast
+
+from superset.commands.base import BaseCommand
+from superset.commands.database.exceptions import DatabaseNotFoundError
+from superset.daos.database import DatabaseUserOAuth2TokensDAO
+from superset.databases.schemas import OAuth2ProviderResponseSchema
+from superset.exceptions import OAuth2Error
+from superset.models.core import Database, DatabaseUserOAuth2Tokens
+from superset.superset_typing import OAuth2State
+from superset.utils.decorators import on_error, transaction
+from superset.utils.oauth2 import decode_oauth2_state
+
+
+class OAuth2StoreTokenCommand(BaseCommand):
+ """
+ Command to store OAuth2 tokens in the database.
+ """
+
+ def __init__(self, parameters: OAuth2ProviderResponseSchema):
+ self._parameters = parameters
+ self._state: OAuth2State | None = None
+ self._database: Database | None = None
+
+ @transaction(on_error=partial(on_error, reraise=OAuth2Error))
+ def run(self) -> DatabaseUserOAuth2Tokens:
+ self.validate()
+ self._database = cast(Database, self._database)
+ self._state = cast(OAuth2State, self._state)
+
+ oauth2_config = self._database.get_oauth2_config()
+ if oauth2_config is None:
+ raise OAuth2Error("No configuration found for OAuth2")
+
+ token_response = self._database.db_engine_spec.get_oauth2_token(
+ oauth2_config,
+ self._parameters["code"],
+ )
+
+ # delete old tokens
+ if existing := DatabaseUserOAuth2TokensDAO.find_one_or_none(
+ user_id=self._state["user_id"],
+ database_id=self._state["database_id"],
+ ):
+ DatabaseUserOAuth2TokensDAO.delete([existing])
+
+ # store tokens
+ expiration = datetime.now() + timedelta(seconds=token_response["expires_in"])
+ return DatabaseUserOAuth2TokensDAO.create(
+ attributes={
+ "user_id": self._state["user_id"],
+ "database_id": self._state["database_id"],
+ "access_token": token_response["access_token"],
+ "access_token_expiration": expiration,
+ "refresh_token": token_response.get("refresh_token"),
+ },
+ )
+
+ def validate(self) -> None:
+ if error := self._parameters.get("error"):
+ raise OAuth2Error(error)
+
+ self._state = decode_oauth2_state(self._parameters["state"])
+
+ if database := DatabaseUserOAuth2TokensDAO.get_database(
+ self._state["database_id"]
+ ):
+ self._database = database
+ else:
+ raise DatabaseNotFoundError("Database not found")
diff --git a/superset/daos/database.py b/superset/daos/database.py
index 09b2fedf93..fa035534ee 100644
--- a/superset/daos/database.py
+++ b/superset/daos/database.py
@@ -195,3 +195,14 @@ class DatabaseUserOAuth2TokensDAO(BaseDAO[DatabaseUserOAuth2Tokens]):
"""
DAO for OAuth2 tokens.
"""
+
+ @classmethod
+ def get_database(cls, database_id: int) -> Database | None:
+ """
+ Returns the database.
+
+ Note that this is different from `DatabaseDAO.find_by_id(database_id)` because
+ this DAO doesn't have any filters, so it can be called even for users without
+ database access (which is necessary for OAuth2).
+ """
+ return db.session.query(Database).filter_by(id=database_id).one_or_none()
diff --git a/superset/databases/api.py b/superset/databases/api.py
index b4196b403c..5c3e024e73 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -19,7 +19,7 @@
from __future__ import annotations
import logging
-from datetime import datetime, timedelta
+from datetime import datetime
from io import BytesIO
from typing import Any, cast
from zipfile import is_zipfile, ZipFile
@@ -46,6 +46,7 @@ from superset.commands.database.exceptions import (
)
from superset.commands.database.export import ExportDatabasesCommand
from superset.commands.database.importers.dispatcher import ImportDatabasesCommand
+from superset.commands.database.oauth2 import OAuth2StoreTokenCommand
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
@@ -72,7 +73,7 @@ from superset.commands.importers.exceptions import (
)
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
-from superset.daos.database import DatabaseDAO, DatabaseUserOAuth2TokensDAO
+from superset.daos.database import DatabaseDAO
from superset.databases.decorators import check_table_access
from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter
from superset.databases.schemas import (
@@ -109,7 +110,6 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
DatabaseNotFoundException,
InvalidPayloadSchemaError,
- OAuth2Error,
OAuth2RedirectError,
SupersetErrorsException,
SupersetException,
@@ -1433,51 +1433,15 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500'
"""
parameters = OAuth2ProviderResponseSchema().load(request.args)
+ command = OAuth2StoreTokenCommand(parameters)
+ command.run()
- if "error" in parameters:
- raise OAuth2Error(parameters["error"])
-
- # note that when decoding the state we will perform JWT validation, preventing a
- # malicious payload that would insert a bogus database token, or delete an
- # existing one.
state = decode_oauth2_state(parameters["state"])
+ tab_id = state["tab_id"]
- # exchange code for access/refresh tokens
- database = DatabaseDAO.find_by_id(state["database_id"], skip_base_filter=True)
- if database is None:
- return self.response_404()
-
- oauth2_config = database.get_oauth2_config()
- if oauth2_config is None:
- raise OAuth2Error("No configuration found for OAuth2")
-
- token_response = database.db_engine_spec.get_oauth2_token(
- oauth2_config,
- parameters["code"],
- )
-
- # delete old tokens
- existing = DatabaseUserOAuth2TokensDAO.find_one_or_none(
- user_id=state["user_id"],
- database_id=state["database_id"],
- )
- if existing:
- DatabaseUserOAuth2TokensDAO.delete([existing])
-
- # store tokens
- expiration = datetime.now() + timedelta(seconds=token_response["expires_in"])
- DatabaseUserOAuth2TokensDAO.create(
- attributes={
- "user_id": state["user_id"],
- "database_id": state["database_id"],
- "access_token": token_response["access_token"],
- "access_token_expiration": expiration,
- "refresh_token": token_response.get("refresh_token"),
- },
- )
# return blank page that closes itself
return make_response(
- render_template("superset/oauth2.html", tab_id=state["tab_id"]),
+ render_template("superset/oauth2.html", tab_id=tab_id),
200,
)
diff --git a/tests/unit_tests/commands/databases/oauth2_test.py b/tests/unit_tests/commands/databases/oauth2_test.py
new file mode 100644
index 0000000000..0fbe2035d2
--- /dev/null
+++ b/tests/unit_tests/commands/databases/oauth2_test.py
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+from pytest_mock import MockerFixture
+
+from superset.commands.database.exceptions import DatabaseNotFoundError
+from superset.commands.database.oauth2 import OAuth2StoreTokenCommand
+from superset.daos.database import DatabaseUserOAuth2TokensDAO
+from superset.databases.schemas import OAuth2ProviderResponseSchema
+from superset.exceptions import OAuth2Error
+from superset.models.core import Database
+from superset.utils.oauth2 import decode_oauth2_state, encode_oauth2_state
+
+
+@pytest.fixture
+def mock_database(mocker: MockerFixture) -> MagicMock:
+ database = mocker.MagicMock(spec=Database)
+ database.get_oauth2_config.return_value = {
+ "client_id": "test",
+ "client_secret": "secret",
+ }
+ database.db_engine_spec.get_oauth2_token.return_value = {
+ "access_token": "test_access_token",
+ "expires_in": 3600,
+ "refresh_token": "test_refresh_token",
+ }
+ return database
+
+
+@pytest.fixture
+def mock_state() -> str:
+ return encode_oauth2_state(
+ {
+ "user_id": 1,
+ "database_id": 123,
+ "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+ "tab_id": "1234",
+ }
+ )
+
+
+@pytest.fixture
+def mock_parameters(mock_state: str) -> dict[str, Any]:
+ return {"code": "test_code", "state": mock_state}
+
+
+def test_validate_success(
+ mocker: MockerFixture,
+ mock_database: MagicMock,
+ mock_state: str,
+ mock_parameters: OAuth2ProviderResponseSchema,
+) -> None:
+ mocker.patch("superset.utils.oauth2.decode_oauth2_state", return_value=mock_state)
+ mocker.patch.object(
+ DatabaseUserOAuth2TokensDAO,
+ "get_database",
+ return_value=mock_database,
+ )
+
+ command = OAuth2StoreTokenCommand(mock_parameters)
+ command.validate()
+
+ assert command._database == mock_database
+ assert command._state == decode_oauth2_state(mock_state)
+
+
+def test_validate_database_not_found(
+ mocker: MockerFixture,
+ mock_parameters: OAuth2ProviderResponseSchema,
+) -> None:
+ mocker.patch(
+ "superset.utils.oauth2.decode_oauth2_state",
+ return_value={"database_id": 999},
+ )
+ mocker.patch.object(DatabaseUserOAuth2TokensDAO, "get_database", return_value=None)
+
+ command = OAuth2StoreTokenCommand(mock_parameters)
+ with pytest.raises(DatabaseNotFoundError, match="Database not found"):
+ command.validate()
+
+
+def test_validate_oauth2_error(mock_parameters: OAuth2ProviderResponseSchema) -> None:
+ mock_parameters["error"] = "OAuth2 failure"
+ command = OAuth2StoreTokenCommand(mock_parameters)
+ with pytest.raises(OAuth2Error, match="Something went wrong while doing OAuth2"):
+ command.validate()
+
+
+def test_run_success(
+ mocker: MockerFixture,
+ mock_database: MagicMock,
+ mock_state: str,
+ mock_parameters: OAuth2ProviderResponseSchema,
+) -> None:
+ mocker.patch.object(
+ DatabaseUserOAuth2TokensDAO,
+ "get_database",
+ return_value=mock_database,
+ )
+ mocker.patch.object(
+ DatabaseUserOAuth2TokensDAO,
+ "find_one_or_none",
+ return_value=None,
+ )
+ mocker.patch.object(DatabaseUserOAuth2TokensDAO, "delete")
+ mock_create = mocker.patch.object(
+ DatabaseUserOAuth2TokensDAO,
+ "create",
+ return_value="new_token",
+ )
+ mocker.patch("superset.utils.oauth2.decode_oauth2_state", return_value=mock_state)
+
+ command = OAuth2StoreTokenCommand(mock_parameters)
+ result = command.run()
+
+ assert result == "new_token"
+ mock_create.assert_called_once()
+
+
+def test_run_existing_token(
+ mocker: MockerFixture,
+ mock_database: MagicMock,
+ mock_state: str,
+ mock_parameters: OAuth2ProviderResponseSchema,
+) -> None:
+ mocker.patch.object(
+ DatabaseUserOAuth2TokensDAO,
+ "get_database",
+ return_value=mock_database,
+ )
+ existing_token = MagicMock()
+ mocker.patch.object(
+ DatabaseUserOAuth2TokensDAO,
+ "find_one_or_none",
+ return_value=existing_token,
+ )
+ mock_delete = mocker.patch.object(DatabaseUserOAuth2TokensDAO, "delete")
+ mock_create = mocker.patch.object(
+ DatabaseUserOAuth2TokensDAO,
+ "create",
+ return_value="new_token",
+ )
+ mocker.patch("superset.utils.oauth2.decode_oauth2_state", return_value=mock_state)
+
+ command = OAuth2StoreTokenCommand(mock_parameters)
+ result = command.run()
+
+ assert result == "new_token"
+ mock_delete.assert_called_once_with([existing_token])
+ mock_create.assert_called_once()
diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py
index 64d99638de..f7a3da00b1 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -41,7 +41,9 @@ from superset.db_engine_specs.sqlite import SqliteEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2RedirectError, SupersetSecurityException
from superset.sql_parse import Table
+from superset.superset_typing import OAuth2State
from superset.utils import json
+from superset.utils.oauth2 import encode_oauth2_state
from tests.unit_tests.fixtures.common import (
create_columnar_file,
create_csv_file,
@@ -752,6 +754,7 @@ def test_oauth2_happy_path(
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
db.session.add(
Database(
+ id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"),
@@ -771,13 +774,12 @@ def test_oauth2_happy_path(
"refresh_token": "ZZZ",
}
- state = {
+ state: OAuth2State = {
"user_id": 1,
"database_id": 1,
- "tab_id": 42,
+ "tab_id": "42",
+ "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
}
- decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
- decode_oauth2_state.return_value = state
mocker.patch("superset.databases.api.render_template", return_value="OK")
@@ -785,13 +787,12 @@ def test_oauth2_happy_path(
response = client.get(
"/api/v1/database/oauth2/",
query_string={
- "state": "some%2Estate",
+ "state": encode_oauth2_state(state),
"code": "XXX",
},
)
assert response.status_code == 200
- decode_oauth2_state.assert_called_with("some%2Estate")
get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
token = db.session.query(DatabaseUserOAuth2Tokens).one()
@@ -841,13 +842,12 @@ def test_oauth2_permissions(
"refresh_token": "ZZZ",
}
- state = {
+ state: OAuth2State = {
"user_id": 1,
"database_id": 1,
- "tab_id": 42,
+ "tab_id": "42",
+ "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
}
- decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
- decode_oauth2_state.return_value = state
mocker.patch("superset.databases.api.render_template", return_value="OK")
@@ -855,13 +855,12 @@ def test_oauth2_permissions(
response = client.get(
"/api/v1/database/oauth2/",
query_string={
- "state": "some%2Estate",
+ "state": encode_oauth2_state(state),
"code": "XXX",
},
)
assert response.status_code == 200
- decode_oauth2_state.assert_called_with("some%2Estate")
get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
token = db.session.query(DatabaseUserOAuth2Tokens).one()
@@ -916,13 +915,12 @@ def test_oauth2_multiple_tokens(
},
]
- state = {
+ state: OAuth2State = {
"user_id": 1,
"database_id": 1,
- "tab_id": 42,
+ "tab_id": "42",
+ "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
}
- decode_oauth2_state = mocker.patch("superset.databases.api.decode_oauth2_state")
- decode_oauth2_state.return_value = state
mocker.patch("superset.databases.api.render_template", return_value="OK")
@@ -930,7 +928,7 @@ def test_oauth2_multiple_tokens(
response = client.get(
"/api/v1/database/oauth2/",
query_string={
- "state": "some%2Estate",
+ "state": encode_oauth2_state(state),
"code": "XXX",
},
)
@@ -939,7 +937,7 @@ def test_oauth2_multiple_tokens(
response = client.get(
"/api/v1/database/oauth2/",
query_string={
- "state": "some%2Estate",
+ "state": encode_oauth2_state(state),
"code": "XXX",
},
)