This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 5660f8e554 feat: OAuth2 client initial work (#29109)
5660f8e554 is described below
commit 5660f8e5542b78e098c42306633c182d9a631d63
Author: Beto Dealmeida <[email protected]>
AuthorDate: Sun Jun 9 22:11:18 2024 -0400
feat: OAuth2 client initial work (#29109)
---
superset/db_engine_specs/base.py | 21 ++++++---
superset/models/core.py | 54 ++++++++++++++---------
superset/utils/oauth2.py | 14 +++++-
tests/unit_tests/conftest.py | 3 ++
tests/unit_tests/models/core_test.py | 83 ++++++++++++++++++++++++++++++++++++
tests/unit_tests/sql_lab_test.py | 62 +++++++++++++++++++++++++++
6 files changed, 210 insertions(+), 27 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 548fb390d8..cd37e4e602 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1610,9 +1610,11 @@ class BaseEngineSpec: # pylint:
disable=too-many-public-methods
@classmethod
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
return [
- literal_column(query_as)
- if (query_as := c.get("query_as"))
- else column(c["column_name"])
+ (
+ literal_column(query_as)
+ if (query_as := c.get("query_as"))
+ else column(c["column_name"])
+ )
for c in cols
]
@@ -1828,13 +1830,18 @@ class BaseEngineSpec: # pylint:
disable=too-many-public-methods
cursor.arraysize = cls.arraysize
try:
cursor.execute(query)
- except cls.oauth2_exception as ex:
- if database.is_oauth2_enabled() and g and g.user:
- cls.start_oauth2_dance(database)
- raise cls.get_dbapi_mapped_exception(ex) from ex
except Exception as ex:
+ if database.is_oauth2_enabled() and cls.needs_oauth2(ex):
+ cls.start_oauth2_dance(database)
raise cls.get_dbapi_mapped_exception(ex) from ex
+ @classmethod
+ def needs_oauth2(cls, ex: Exception) -> bool:
+ """
+ Check if the exception indicates OAuth2 is needed and a user is set in `g`.
+ """
+ return bool(g and getattr(g, "user", None) and isinstance(ex, cls.oauth2_exception))
+
@classmethod
def make_label_compatible(cls, label: str) -> str | quoted_name:
"""
diff --git a/superset/models/core.py b/superset/models/core.py
index e6d97a197b..c8c875e435 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -29,7 +29,7 @@ from contextlib import closing, contextmanager, nullcontext,
suppress
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, Callable, cast, TYPE_CHECKING
import numpy
import pandas as pd
@@ -78,7 +78,7 @@ from superset.superset_typing import OAuth2ClientConfig,
ResultSetColumnType
from superset.utils import cache as cache_util, core as utils, json
from superset.utils.backports import StrEnum
from superset.utils.core import DatasourceName, get_username
-from superset.utils.oauth2 import get_oauth2_access_token
+from superset.utils.oauth2 import get_oauth2_access_token,
OAuth2ClientConfigSchema
config = app.config
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
@@ -554,17 +554,23 @@ class Database(Model, AuditMixinNullable,
ImportExportMixin): # pylint: disable
nullpool=nullpool,
source=source,
) as engine:
- with closing(engine.raw_connection()) as conn:
- # pre-session queries are used to set the selected schema and,
in the
- # future, the selected catalog
- for prequery in self.db_engine_spec.get_prequeries(
- catalog=catalog,
- schema=schema,
- ):
- cursor = conn.cursor()
- cursor.execute(prequery)
+ try:
+ with closing(engine.raw_connection()) as conn:
+ # pre-session queries are used to set the selected schema
and, in the
+ # future, the selected catalog
+ for prequery in self.db_engine_spec.get_prequeries(
+ catalog=catalog,
+ schema=schema,
+ ):
+ cursor = conn.cursor()
+ cursor.execute(prequery)
- yield conn
+ yield conn
+
+ except Exception as ex:
+ if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
+ self.db_engine_spec.start_oauth2_dance(self)
+ raise
def get_default_catalog(self) -> str | None:
"""
@@ -1063,20 +1069,30 @@ class Database(Model, AuditMixinNullable,
ImportExportMixin): # pylint: disable
"""
Is OAuth2 enabled in the database for authentication?
- Currently this looks for a global config at the DB engine spec level,
but in the
- future we want to be allow admins to create custom OAuth2 clients from
the
- Superset UI, and assign them to specific databases.
+ Currently this checks for configuration stored in the database
`extra`, and then
+ for a global config at the DB engine spec level. In the future we want
to allow
+ admins to create custom OAuth2 clients from the Superset UI, and
assign them to
+ specific databases.
"""
- return self.db_engine_spec.is_oauth2_enabled()
+ encrypted_extra = json.loads(self.encrypted_extra or "{}")
+ oauth2_client_info = encrypted_extra.get("oauth2_client_info", {})
+ return bool(oauth2_client_info) or
self.db_engine_spec.is_oauth2_enabled()
def get_oauth2_config(self) -> OAuth2ClientConfig | None:
"""
Return OAuth2 client configuration.
- This includes client ID, client secret, scope, redirect URI,
endpointsm etc.
- Currently this reads the global DB engine spec config, but in the
future it
- should first check if there's a custom client assigned to the database.
+ Currently this checks for configuration stored in the database
`extra`, and then
+ for a global config at the DB engine spec level. In the future we want
to allow
+ admins to create custom OAuth2 clients from the Superset UI, and
assign them to
+ specific databases.
"""
+ encrypted_extra = json.loads(self.encrypted_extra or "{}")
+ if oauth2_client_info := encrypted_extra.get("oauth2_client_info"):
+ schema = OAuth2ClientConfigSchema()
+ client_config = schema.load(oauth2_client_info)
+ return cast(OAuth2ClientConfig, client_config)
+
return self.db_engine_spec.get_oauth2_config()
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index 9cc58a0b7f..bc4805fd81 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -22,7 +22,7 @@ from typing import Any, TYPE_CHECKING
import backoff
import jwt
-from flask import current_app
+from flask import current_app, url_for
from marshmallow import EXCLUDE, fields, post_load, Schema
from superset import db
@@ -180,3 +180,15 @@ def decode_oauth2_state(encoded_state: str) -> OAuth2State:
state = oauth2_state_schema.load(payload)
return state
+
+
+class OAuth2ClientConfigSchema(Schema):
+ id = fields.String(required=True)
+ secret = fields.String(required=True)
+ scope = fields.String(required=True)
+ redirect_uri = fields.String(
+ required=False,
+ load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True),
+ )
+ authorization_request_uri = fields.String(required=True)
+ token_request_uri = fields.String(required=True)
diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index 3905a15b32..2b8f39d6dd 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -90,6 +90,9 @@ def app(request: SubRequest) -> Iterator[SupersetApp]:
app.config["RATELIMIT_ENABLED"] = False
app.config["CACHE_CONFIG"] = {}
app.config["DATA_CACHE_CONFIG"] = {}
+ app.config["SERVER_NAME"] = "example.com"
+ app.config["APPLICATION_ROOT"] = "/"
+ app.config["PREFERRED_URL_SCHEME"] = "http"
# loop over extra configs passed in by tests
# and update the app config
diff --git a/tests/unit_tests/models/core_test.py
b/tests/unit_tests/models/core_test.py
index 2004ff482f..c4d642baf5 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=import-outside-toplevel
+
from datetime import datetime
import pytest
@@ -24,11 +25,23 @@ from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from superset.connectors.sqla.models import SqlaTable, TableColumn
+from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.models.core import Database
from superset.sql_parse import Table
from superset.utils import json
from tests.unit_tests.conftest import with_feature_flags
+# sample config for OAuth2 tests
+oauth2_client_info = {
+ "oauth2_client_info": {
+ "id": "my_client_id",
+ "secret": "my_client_secret",
+ "authorization_request_uri":
"https://abcd1234.snowflakecomputing.com/oauth/authorize",
+ "token_request_uri":
"https://abcd1234.snowflakecomputing.com/oauth/token-request",
+ "scope": "refresh_token session:role:SYSADMIN",
+ }
+}
+
def test_get_metrics(mocker: MockFixture) -> None:
"""
@@ -378,3 +391,73 @@ def test_get_sqla_engine_user_impersonation_email(mocker:
MockFixture) -> None:
make_url("trino:///"),
connect_args={"user": "alice.doe", "source": "Apache Superset"},
)
+
+
+def test_is_oauth2_enabled() -> None:
+ """
+ Test the `is_oauth2_enabled` method.
+ """
+ database = Database(
+ database_name="db",
+ sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+ )
+
+ assert not database.is_oauth2_enabled()
+
+ database.encrypted_extra = json.dumps(oauth2_client_info)
+ assert database.is_oauth2_enabled()
+
+
+def test_get_oauth2_config(app_context: None) -> None:
+ """
+ Test the `get_oauth2_config` method.
+ """
+ database = Database(
+ database_name="db",
+ sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+ )
+
+ assert database.get_oauth2_config() is None
+
+ database.encrypted_extra = json.dumps(oauth2_client_info)
+ assert database.get_oauth2_config() == {
+ "id": "my_client_id",
+ "secret": "my_client_secret",
+ "authorization_request_uri":
"https://abcd1234.snowflakecomputing.com/oauth/authorize",
+ "token_request_uri":
"https://abcd1234.snowflakecomputing.com/oauth/token-request",
+ "scope": "refresh_token session:role:SYSADMIN",
+ "redirect_uri": "http://example.com/api/v1/database/oauth2/",
+ }
+
+
+def test_raw_connection_oauth(mocker: MockFixture) -> None:
+ """
+ Test that we can start OAuth2 from `raw_connection()` errors.
+
+ Some databases that use OAuth2 need to trigger the flow when the
connection is
+ created, rather than when the query runs. This happens when the SQLAlchemy
engine
+ URI cannot be built without the user personal token.
+
+ This test verifies that the exception is captured and raised correctly so
that the
+ frontend can trigger the OAuth2 dance.
+ """
+ g = mocker.patch("superset.db_engine_specs.base.g")
+ g.user = mocker.MagicMock()
+ g.user.id = 42
+
+ database = Database(
+ id=1,
+ database_name="my_db",
+ sqlalchemy_uri="sqlite://",
+ encrypted_extra=json.dumps(oauth2_client_info),
+ )
+ mocker.patch.object(database.db_engine_spec, "oauth2_exception", OAuth2Error)  # restored after the test
+ get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
+ get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error(
+ "OAuth2 required"
+ )
+
+ with pytest.raises(OAuth2RedirectError) as excinfo:
+ with database.get_raw_connection() as conn:
+ conn.cursor()
+ assert str(excinfo.value) == "You don't have permission to access the
data."
diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py
index 3b2e7690e1..6ac055df13 100644
--- a/tests/unit_tests/sql_lab_test.py
+++ b/tests/unit_tests/sql_lab_test.py
@@ -16,12 +16,22 @@
# under the License.
# pylint: disable=import-outside-toplevel, invalid-name, unused-argument,
too-many-locals
+import json
+from uuid import UUID
+
import sqlparse
+from freezegun import freeze_time
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session
from superset import db
+from superset.common.db_query_status import QueryStatus
+from superset.errors import ErrorLevel, SupersetErrorType
+from superset.exceptions import OAuth2Error
+from superset.models.core import Database
+from superset.sql_lab import get_sql_results
from superset.utils.core import override_user
+from tests.unit_tests.models.core_test import oauth2_client_info
def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
@@ -218,3 +228,55 @@ def test_sql_lab_insert_rls_as_subquery(
query.executed_sql
== "SELECT c FROM (SELECT * FROM t WHERE (t.c > 5)) AS t\nLIMIT 6"
)
+
+
+@freeze_time("2021-04-01T00:00:00Z")
+def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
+ """
+ Test that `get_sql_results` works with OAuth2.
+ """
+ app_context = app.test_request_context()
+ app_context.push()
+
+ mocker.patch(
+ "superset.db_engine_specs.base.uuid4",
+ return_value=UUID("fb11f528-6eba-4a8a-837e-6b0d39ee9187"),
+ )
+
+ g = mocker.patch("superset.db_engine_specs.base.g")
+ g.user = mocker.MagicMock()
+ g.user.id = 42
+
+ database = Database(
+ id=1,
+ database_name="my_db",
+ sqlalchemy_uri="sqlite://",
+ encrypted_extra=json.dumps(oauth2_client_info),
+ )
+ mocker.patch.object(database.db_engine_spec, "oauth2_exception", OAuth2Error)  # restored after the test
+ get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
+ get_sqla_engine().__enter__().raw_connection.side_effect = OAuth2Error(
+ "OAuth2 required"
+ )
+
+ query = mocker.MagicMock()
+ query.database = database
+ mocker.patch("superset.sql_lab.get_query", return_value=query)
+
+ payload = get_sql_results(query_id=1, rendered_query="SELECT 1")
+ assert payload == {
+ "status": QueryStatus.FAILED,
+ "error": "You don't have permission to access the data.",
+ "errors": [
+ {
+ "message": "You don't have permission to access the data.",
+ "error_type": SupersetErrorType.OAUTH2_REDIRECT,
+ "level": ErrorLevel.WARNING,
+ "extra": {
+ "url":
"https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3ASYSADMIN&access_type=offline&include_granted_scopes=false&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vZXhhbXBsZS5jb20vYXBpL3YxL2RhdGFiYXNlL29hdXRoMi8iLCJ0YWJfaWQiOiJmYjExZjUyOC02ZWJhLTRhOGEtODM3ZS02YjBkMzllZTkxODcifQ%252Ec_m_35xwwSrLgCXwV4aKO192
[...]
+ "tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187",
+ "redirect_uri":
"http://example.com/api/v1/database/oauth2/",
+ },
+ }
+ ],
+ }