This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch sip-85-b in repository https://gitbox.apache.org/repos/asf/superset.git
commit d08859212cae98c27152e20b7a125eb637aeb4c1 Author: Beto Dealmeida <[email protected]> AuthorDate: Wed Apr 3 10:12:51 2024 -0400 WIP --- superset/config.py | 20 ++- superset/connectors/sqla/utils.py | 2 +- superset/db_engine_specs/README.md | 62 ++++---- superset/db_engine_specs/base.py | 174 ++++++++++++++++------- superset/db_engine_specs/gsheets.py | 89 +----------- superset/db_engine_specs/hive.py | 1 + superset/db_engine_specs/presto.py | 2 + superset/db_engine_specs/trino.py | 2 + superset/models/core.py | 36 ++++- superset/sql_validators/presto_db.py | 2 +- superset/superset_typing.py | 49 +++++++ superset/utils/oauth2.py | 12 +- tests/unit_tests/db_engine_specs/test_gsheets.py | 10 +- 13 files changed, 271 insertions(+), 190 deletions(-) diff --git a/superset/config.py b/superset/config.py index b6dbfc9fee..bb5b04b232 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1409,12 +1409,20 @@ TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30) # Details needed for databases that allows user to authenticate using personal # OAuth2 tokens. See https://github.com/apache/superset/issues/20300 for more -# information -DATABASE_OAUTH2_CREDENTIALS: dict[str, dict[str, Any]] = { +# information. The scope and URIs are optional. +DATABASE_OAUTH2_CLIENTS: dict[str, dict[str, Any]] = { # "Google Sheets": { - # "CLIENT_ID": "XXX.apps.googleusercontent.com", - # "CLIENT_SECRET": "GOCSPX-YYY", - # "BASEURL": "https://accounts.google.com/o/oauth2/v2/auth", + # "id": "XXX.apps.googleusercontent.com", + # "secret": "GOCSPX-YYY", + # "scope": " ".join( + # [ + # "https://www.googleapis.com/auth/drive.readonly", + # "https://www.googleapis.com/auth/spreadsheets", + # "https://spreadsheets.google.com/feeds", + # ] + # ), + # "authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth", + # "token_request_uri": "https://oauth2.googleapis.com/token", # }, } # OAuth2 state is encoded in a JWT using the alogorithm below. 
@@ -1425,6 +1433,8 @@ DATABASE_OAUTH2_JWT_ALGORITHM = "HS256" # applications. In that case, the proxy can forward the request to the correct instance # by looking at the `default_redirect_uri` attribute in the OAuth2 state object. # DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/" +# Timeout when fetching access and refresh tokens. +DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30) # Enable/disable CSP warning CONTENT_SECURITY_POLICY_WARNING = True diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index d0922e40f3..4bc11aee42 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -145,7 +145,7 @@ def get_columns_description( cursor = conn.cursor() query = database.apply_limit_to_sql(query, limit=1) cursor.execute(query) - db_engine_spec.execute(cursor, query, database.id) + db_engine_spec.execute(cursor, query, database) result = db_engine_spec.fetch_data(cursor, limit=1) result_set = SupersetResultSet(result, cursor.description, db_engine_spec) return result_set.columns diff --git a/superset/db_engine_specs/README.md b/superset/db_engine_specs/README.md index 0be1f29146..d5d0d6cb4a 100644 --- a/superset/db_engine_specs/README.md +++ b/superset/db_engine_specs/README.md @@ -547,65 +547,53 @@ Alternatively, it's also possible to impersonate users by implementing the `upda Support for authenticating to a database using personal OAuth2 access tokens was introduced in [SIP-85](https://github.com/apache/superset/issues/20300). The Google Sheets DB engine spec is the reference implementation. 
-To add support for OAuth2 to a DB engine spec, the following attribute and methods are needed: +To add support for OAuth2 to a DB engine spec, the following attributes are needed: ```python class BaseEngineSpec: + supports_oauth2 = True oauth2_exception = OAuth2RedirectError - @staticmethod - def is_oauth2_enabled() -> bool: - return False - - @staticmethod - def get_oauth2_authorization_uri(state: OAuth2State) -> str: - raise NotImplementedError() - - @staticmethod - def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse: - raise NotImplementedError() - - @staticmethod - def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse: - raise NotImplementedError() + oauth2_scope = " ".join([ + "https://example.org/scope1", + "https://example.org/scope2", + ]) + oauth2_authorization_request_uri = "https://example.org/authorize" + oauth2_token_request_uri = "https://example.org/token" ``` The `oauth2_exception` is an exception that is raised by `cursor.execute` when OAuth2 is needed. This will start the OAuth2 dance when `BaseEngineSpec.execute` is called, by returning the custom error `OAUTH2_REDIRECT` to the frontend. If the database driver doesn't have a specific exception, it might be necessary to overload the `execute` method in the DB engine spec, so that the `BaseEngineSpec.start_oauth2_dance` method gets called whenever OAuth2 is needed. -The first method, `is_oauth2_enabled`, is used to inform if the database supports OAuth2. This can be dynamic; for example, the Google Sheets DB engine spec checks if the Superset configuration has the necessary section: - -```python -from flask import current_app - +The DB engine should implement logic in either `get_url_for_impersonation` or `update_impersonation_config` to update the connection with the personal access token. See the Google Sheets DB engine spec for a reference implementation. 
-class GSheetsEngineSpec(ShillelaghEngineSpec): - @staticmethod - def is_oauth2_enabled() -> bool: - return "Google Sheets" in current_app.config["DATABASE_OAUTH2_CREDENTIALS"] -``` - -Where the configuration for OAuth2 would look like this: +Currently OAuth2 needs to be configured at the DB engine spec level, i.e., with one client for each DB engine spec. The configuration lives in `superset_config.py`: ```python # superset_config.py -DATABASE_OAUTH2_CREDENTIALS = { +DATABASE_OAUTH2_CLIENTS = { "Google Sheets": { - "CLIENT_ID": "XXX.apps.googleusercontent.com", - "CLIENT_SECRET": "GOCSPX-YYY", + "id": "XXX.apps.googleusercontent.com", + "secret": "GOCSPX-YYY", + "scope": " ".join( + [ + "https://www.googleapis.com/auth/drive.readonly", + "https://www.googleapis.com/auth/spreadsheets", + "https://spreadsheets.google.com/feeds", + ], + ), + "authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth", + "token_request_uri": "https://oauth2.googleapis.com/token", }, } DATABASE_OAUTH2_JWT_ALGORITHM = "HS256" DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/" +DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30) ``` -The second method, `get_oauth2_authorization_uri`, is responsible for building the URL where the user is sent to initiate OAuth2. This method receives a `state`. The state is an encoded JWT that is passed to the OAuth2 provider, and is received unmodified when the user is redirected back to Superset. The default state contains the user ID and the database ID, so that Superset can know where to store the received OAuth2 tokens. - -Additionally, the state also contains a `tab_id`, which is a random UUID4 used as a shared secret for communication between browser tabs. When OAuth2 starts, Superset will open a new browser tab, where the user will grant permissions to Superset. When authentication is complete and successful this opened tab will send a message to the original tab, so that the original query can be re-run. 
The `tab_id` is sent by the opened tab and verified by the original tab to prevent malicious messag [...] - -State also contains a `defaul_redirect_uri`, which is the enpoint in Supeset that receives the tokens from the OAuth2 provider (`/api/v1/database/oauth2/`). The redirect URL can be overwritten in the config file via the `DATABASE_OAUTH2_REDIRECT_URI` parameter. This might be useful where you have multiple Superset instances. Since the OAuth2 provider requires the redirect URL to be registered a priori, it might be easier (or needed) to register a single URL for a proxy service; the proxy [...] +When configuring a client only the ID and secret are required; the DB engine spec should have default values for the scope and endpoints. The `DATABASE_OAUTH2_REDIRECT_URI` attribute is optional, and defaults to `/api/v1/database/oauth2/` in Superset. -Finally, `get_oauth2_token` and `get_oauth2_fresh_token` are used to actually retrieve a token and refresh an expired token, respectively. +In the future we plan to support adding custom clients via the Superset UI, and being able to manually assign clients to specific databases. 
### File upload diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 12797fc6a3..8fdffe2095 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -33,9 +33,11 @@ from typing import ( TypedDict, Union, ) +from urllib.parse import urlencode, urljoin from uuid import uuid4 import pandas as pd +import requests import sqlparse from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin @@ -60,13 +62,20 @@ from superset import security_manager, sql_parse from superset.constants import TimeGrain as TimeGrainConstants from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import OAuth2Error, OAuth2RedirectError +from superset.exceptions import OAuth2RedirectError from superset.sql_parse import ParsedQuery, SQLScript, Table -from superset.superset_typing import ResultSetColumnType, SQLAColumnType +from superset.superset_typing import ( + OAuth2ClientConfig, + OAuth2State, + OAuth2TokenResponse, + ResultSetColumnType, + SQLAColumnType, +) from superset.utils import core as utils from superset.utils.core import ColumnSpec, GenericDataType from superset.utils.hashing import md5_sha_from_str from superset.utils.network import is_hostname_valid, is_port_open +from superset.utils.oauth2 import encode_oauth2_state if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn @@ -173,31 +182,6 @@ class MetricType(TypedDict, total=False): extra: str | None -class OAuth2TokenResponse(TypedDict, total=False): - """ - Type for an OAuth2 response when exchanging or refreshing tokens. - """ - - access_token: str - expires_in: int - scope: str - token_type: str - - # only present when exchanging code for refresh/access tokens - refresh_token: str - - -class OAuth2State(TypedDict): - """ - Type for the state passed during OAuth2. 
- """ - - database_id: int - user_id: int - default_redirect_uri: str - tab_id: str - - class BaseEngineSpec: # pylint: disable=too-many-public-methods """Abstract class for database engine specific configurations @@ -425,15 +409,25 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # Can the catalog be changed on a per-query basis? supports_dynamic_catalog = False + # Does the engine supports OAuth 2.0? This requires logic to be added to one of the + # the user impersonation methods to handle personal tokens. + supports_oauth2 = False + oauth2_scope = "" + oauth2_authorization_request_uri = "" + oauth2_token_request_uri = "" + # Driver-specific exception that should be mapped to OAuth2RedirectError - oauth2_exception = OAuth2RedirectError + oauth2_exception: Exception = OAuth2RedirectError - @staticmethod - def is_oauth2_enabled() -> bool: - return False + @classmethod + def is_oauth2_enabled(cls) -> bool: + return ( + cls.supports_oauth2 + and cls.engine_name in current_app.config["DATABASE_OAUTH2_CLIENTS"] + ) @classmethod - def start_oauth2_dance(cls, database_id: int) -> None: + def start_oauth2_dance(cls, database: "Database") -> None: """ Start the OAuth2 dance. @@ -446,10 +440,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ tab_id = str(uuid4()) default_redirect_uri = url_for("DatabaseRestApi.oauth2", _external=True) - redirect_uri = current_app.config.get( - "DATABASE_OAUTH2_REDIRECT_URI", - default_redirect_uri, - ) # The state is passed to the OAuth2 provider, and sent back to Superset after # the user authorizes the access. The redirect endpoint in Superset can then @@ -457,7 +447,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # belongs to. state: OAuth2State = { # Database ID and user ID are the primary key associated with the token. 
- "database_id": database_id, + "database_id": database.id, "user_id": g.user.id, # In multi-instance deployments there might be a single proxy handling # redirects, with a custom `DATABASE_OAUTH2_REDIRECT_URI`. Since the OAuth2 @@ -473,30 +463,112 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # message. "tab_id": tab_id, } - oauth_url = cls.get_oauth2_authorization_uri(state) + config = database.get_oauth2_config() + oauth_url = cls.get_oauth2_authorization_uri(config, state) - raise OAuth2RedirectError(oauth_url, tab_id, redirect_uri) + raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri) - @staticmethod - def get_oauth2_authorization_uri(state: OAuth2State) -> str: + @classmethod + def get_oauth2_config(cls) -> OAuth2ClientConfig: + """ + Build the OAuth2 client config. + + Currently this is built based on the `config.py` file, since clients are defined + at the DB engine spec level. In the future we'll allow admins to create custom + OAuth2 clients and assign them to specific databases. + """ + redirect_uri = current_app.config.get( + "DATABASE_OAUTH2_REDIRECT_URI", + url_for("DatabaseRestApi.oauth2", _external=True), + ) + + oauth2_config = current_app.config["DATABASE_OAUTH2_CLIENTS"] + try: + db_engine_spec_config = oauth2_config[cls.engine_name] + except KeyError: + raise Exception("TODO") + + config: OAuth2ClientConfig = { + "id": db_engine_spec_config["id"], + "secret": db_engine_spec_config["secret"], + "scope": db_engine_spec_config.get("scope") or cls.oauth2_scope, + "redirect_uri": redirect_uri, + "authorization_request_uri": db_engine_spec_config.get( + "authorization_request_uri" + ) + or cls.oauth2_authorization_request_uri, + "token_request_uri": db_engine_spec_config.get("token_request_uri") + or cls.oauth2_token_request_uri, + } + + return config + + @classmethod + def get_oauth2_authorization_uri( + cls, + config: OAuth2ClientConfig, + state: OAuth2State, + ) -> str: """ Return URI for initial OAuth2 request. 
""" - raise OAuth2Error("Subclasses must implement `get_oauth2_authorization_uri`") + uri = config["authorization_request_uri"] + params = { + "scope": config["scope"], + "access_type": "offline", + "include_granted_scopes": "false", + "response_type": "code", + "state": encode_oauth2_state(state), + "redirect_uri": config["redirect_uri"], + "client_id": config["id"], + "prompt": "consent", + } + return urljoin(uri, "?" + urlencode(params)) - @staticmethod - def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse: + @classmethod + def get_oauth2_token( + config: OAuth2ClientConfig, + code: str, + ) -> OAuth2TokenResponse: """ Exchange authorization code for refresh/access tokens. """ - raise OAuth2Error("Subclasses must implement `get_oauth2_token`") + timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() + uri = config["token_request_uri"] + response = requests.post( + uri, + fields={ + "code": code, + "client_id": config["id"], + "client_secret": config["secret"], + "redirect_uri": config["redirect_uri"], + "grant_type": "authorization_code", + }, + timeout=timeout, + ) + return response.json() @staticmethod - def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse: + def get_oauth2_fresh_token( + config: OAuth2ClientConfig, + refresh_token: str, + ) -> OAuth2TokenResponse: """ Refresh an access token that has expired. 
""" - raise OAuth2Error("Subclasses must implement `get_oauth2_fresh_token`") + timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() + uri = config["token_request_uri"] + response = requests.post( + uri, + fields={ + "client_id": config["id"], + "client_secret": config["secret"], + "refresh_token": refresh_token, + "grant_type": "refresh_token", + }, + timeout=timeout, + ) + return response.json() @classmethod def get_allows_alias_in_select( @@ -1667,6 +1739,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods connect_args: dict[str, Any], uri: str, username: str | None, + access_token: str | None, ) -> None: """ Update a configuration dictionary @@ -1675,6 +1748,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param connect_args: config to be updated :param uri: URI :param username: Effective username + :param access_token: Personal access token for OAuth2 :return: None """ @@ -1683,7 +1757,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, cursor: Any, query: str, - database_id: int, + database: "Database", **kwargs: Any, ) -> None: """ @@ -1703,8 +1777,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods try: cursor.execute(query) except cls.oauth2_exception as ex: - if cls.is_oauth2_enabled() and g.user: - cls.start_oauth2_dance(database_id) + if database.is_oauth2_enabled() and g.user: + cls.start_oauth2_dance(database) raise cls.get_dbapi_mapped_exception(ex) from ex except Exception as ex: raise cls.get_dbapi_mapped_exception(ex) from ex diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 28e95811d4..f74982fe4b 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -23,13 +23,11 @@ import logging import re from re import Pattern from typing import Any, TYPE_CHECKING, TypedDict -from urllib.parse import urlencode, urljoin import pandas as pd -import urllib3 from apispec import APISpec 
from apispec.ext.marshmallow import MarshmallowPlugin -from flask import current_app, g +from flask import g from flask_babel import gettext as __ from marshmallow import fields, Schema from marshmallow.exceptions import ValidationError @@ -42,11 +40,9 @@ from sqlalchemy.engine.url import URL from superset import db, security_manager from superset.constants import PASSWORD_MASK from superset.databases.schemas import encrypted_field_properties, EncryptedString -from superset.db_engine_specs.base import OAuth2State, OAuth2TokenResponse from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException -from superset.utils.oauth2 import encode_oauth2_state if TYPE_CHECKING: from superset.models.core import Database @@ -62,7 +58,6 @@ EXAMPLE_GSHEETS_URL = ( SYNTAX_ERROR_REGEX = re.compile('SQLError: near "(?P<server_error>.*?)": syntax error') ma_plugin = MarshmallowPlugin() -http = urllib3.PoolManager() class GSheetsParametersSchema(Schema): @@ -111,7 +106,11 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): supports_file_upload = True - # exception raised by shillelagh that should trigger OAuth2 + # OAuth 2.0 + supports_oauth2 = True + oauth2_scope = " ".join(SCOPES) + oauth2_authorization_request_uri = "https://accounts.google.com/o/oauth2/v2/auth" + oauth2_token_request_uri = "https://oauth2.googleapis.com/token" oauth2_exception = UnauthenticatedError @classmethod @@ -153,82 +152,6 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): return {"metadata": metadata["extra"]} - @staticmethod - def is_oauth2_enabled() -> bool: - """ - Return if OAuth2 is enabled for GSheets. - """ - return "Google Sheets" in current_app.config["DATABASE_OAUTH2_CREDENTIALS"] - - @classmethod - def get_oauth2_authorization_uri(cls, state: OAuth2State) -> str: - """ - Return URI for initial OAuth2 request. 
- - https://developers.google.com/identity/protocols/oauth2/web-server#creatingclient - """ - config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"] - baseurl = config.get("BASEURL", "https://accounts.google.com/o/oauth2/v2/auth") - redirect_uri = current_app.config.get( - "DATABASE_OAUTH2_REDIRECT_URI", - state["default_redirect_uri"], - ) - - params = { - "scope": " ".join(SCOPES), - "access_type": "offline", - "include_granted_scopes": "false", - "response_type": "code", - "state": encode_oauth2_state(state), - "redirect_uri": redirect_uri, - "client_id": config["CLIENT_ID"], - "prompt": "consent", - } - return urljoin(baseurl, "?" + urlencode(params)) - - @staticmethod - def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse: - """ - Exchange authorization code for refresh/access tokens. - """ - config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"] - redirect_uri = current_app.config.get( - "DATABASE_OAUTH2_REDIRECT_URI", - state["default_redirect_uri"], - ) - - response = http.request( - "POST", - "https://oauth2.googleapis.com/token", - fields={ - "code": code, - "client_id": config["CLIENT_ID"], - "client_secret": config["CLIENT_SECRET"], - "redirect_uri": redirect_uri, - "grant_type": "authorization_code", - }, - ) - return json.loads(response.data.decode("utf-8")) - - @staticmethod - def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse: - """ - Refresh an access token that has expired. 
- """ - config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"] - - response = http.request( - "POST", - "https://oauth2.googleapis.com/token", - fields={ - "client_id": config["CLIENT_ID"], - "client_secret": config["CLIENT_SECRET"], - "refresh_token": refresh_token, - "grant_type": "refresh_token", - }, - ) - return json.loads(response.data.decode("utf-8")) - @classmethod def build_sqlalchemy_uri( cls, diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index a97dd88aef..4b4c2d951b 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -528,6 +528,7 @@ class HiveEngineSpec(PrestoEngineSpec): connect_args: dict[str, Any], uri: str, username: str | None, + access_token: str | None, ) -> None: """ Update a configuration dictionary diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 0df8d53f4f..ab44edc258 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -717,6 +717,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): connect_args: dict[str, Any], uri: str, username: str | None, + access_token: str | None, ) -> None: """ Update a configuration dictionary @@ -724,6 +725,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): :param connect_args: config to be updated :param uri: URI string :param username: Effective username + :param access_token: Personal access token for OAuth2 :return: None """ url = make_url_safe(uri) diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 4513d63c60..4eeac72a7b 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -112,6 +112,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): connect_args: dict[str, Any], uri: str, username: str | None, + access_token: str | None, ) -> None: """ Update a configuration dictionary @@ -119,6 +120,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): :param connect_args: config to be 
updated :param uri: URI string :param username: Effective username + :param access_token: Personal access token for OAuth2 :return: None """ url = make_url_safe(uri) diff --git a/superset/models/core.py b/superset/models/core.py index cf84b90ac2..eb165e4dfa 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -71,7 +71,7 @@ from superset.extensions import ( ) from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.result_set import SupersetResultSet -from superset.superset_typing import ResultSetColumnType +from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType from superset.utils import cache as cache_util, core as utils from superset.utils.backports import StrEnum from superset.utils.core import get_username @@ -467,7 +467,12 @@ class Database( effective_username = self.get_effective_user(sqlalchemy_url) access_token = ( - get_oauth2_access_token(self.id, g.user.id, self.db_engine_spec) + get_oauth2_access_token( + self.get_oauth2_config(), + self.id, + g.user.id, + self.db_engine_spec, + ) if hasattr(g, "user") and hasattr(g.user, "id") else None ) @@ -489,6 +494,7 @@ class Database( connect_args, str(sqlalchemy_url), effective_username, + access_token, ) if connect_args: @@ -599,7 +605,7 @@ class Database( database=None, ) _log_query(sql_) - self.db_engine_spec.execute(cursor, sql_, self.id) + self.db_engine_spec.execute(cursor, sql_, self) cursor.fetchall() if mutate_after_split: @@ -609,10 +615,10 @@ class Database( database=None, ) _log_query(last_sql) - self.db_engine_spec.execute(cursor, last_sql, self.id) + self.db_engine_spec.execute(cursor, last_sql, self) else: _log_query(sqls[-1]) - self.db_engine_spec.execute(cursor, sqls[-1], self.id) + self.db_engine_spec.execute(cursor, sqls[-1], self) data = self.db_engine_spec.fetch_data(cursor) result_set = SupersetResultSet( @@ -983,6 +989,26 @@ class Database( sqla_col.key = label_expected return sqla_col + def is_oauth2_enabled(self) -> 
bool: + """ + Is OAuth2 enabled in the database for authentication? + + Currently this looks for a global config at the DB engine spec level, but in the + future we want to allow admins to create custom OAuth2 clients from the + Superset UI, and assign them to specific databases. + """ + return self.db_engine_spec.is_oauth2_enabled() + + def get_oauth2_config(self) -> OAuth2ClientConfig: + """ + Return OAuth2 client configuration. + + This includes client ID, client secret, scope, redirect URI, endpoints, etc. + Currently this reads the global DB engine spec config, but in the future it + should first check if there's a custom client assigned to the database. + """ + return self.db_engine_spec.get_oauth2_config() + sqla.event.listen(Database, "after_insert", security_manager.database_after_insert) sqla.event.listen(Database, "after_update", security_manager.database_after_update) diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index 8c815ad63e..4852f70ee4 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -73,7 +73,7 @@ class PrestoDBSQLValidator(BaseSQLValidator): from pyhive.exc import DatabaseError try: - db_engine_spec.execute(cursor, sql, database.id) + db_engine_spec.execute(cursor, sql, database) polled = cursor.poll() while polled: logger.info("polling presto for validation progress") diff --git a/superset/superset_typing.py b/superset/superset_typing.py index c71dcea3f1..ba623f5819 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -121,3 +121,52 @@ FlaskResponse = Union[ tuple[Base, Status, Headers], tuple[Response, Status], ] + + +class OAuth2ClientConfig(TypedDict): + """ + Configuration for an OAuth2 client. + """ + + # The client ID and secret. + id: str + secret: str + + # The scopes requested; this is usually a space separated list of URLs. 
+ scope: str + + # The URI where the user is redirected to after authorizing the client; by default + # this points to `/api/v1/database/oauth2/`, but it can be overridden by the admin. + redirect_uri: str + + # The URI used to get a code. + authorization_request_uri: str + + # The URI used when exchanging the code for an access token, or when refreshing an + # expired access token. + token_request_uri: str + + +class OAuth2TokenResponse(TypedDict, total=False): + """ + Type for an OAuth2 response when exchanging or refreshing tokens. + """ + + access_token: str + expires_in: int + scope: str + token_type: str + + # only present when exchanging code for refresh/access tokens + refresh_token: str + + +class OAuth2State(TypedDict): + """ + Type for the state passed during OAuth2. + """ + + database_id: int + user_id: int + default_redirect_uri: str + tab_id: str diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 7e80df9599..9cc58a0b7f 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -26,11 +26,12 @@ from flask import current_app from marshmallow import EXCLUDE, fields, post_load, Schema from superset import db -from superset.db_engine_specs.base import BaseEngineSpec, OAuth2State from superset.exceptions import CreateKeyValueDistributedLockFailedException +from superset.superset_typing import OAuth2ClientConfig, OAuth2State from superset.utils.lock import KeyValueDistributedLock if TYPE_CHECKING: + from superset.db_engine_specs.base import BaseEngineSpec from superset.models.core import DatabaseUserOAuth2Tokens JWT_EXPIRATION = timedelta(minutes=5) @@ -44,6 +45,7 @@ JWT_EXPIRATION = timedelta(minutes=5) max_tries=5, ) def get_oauth2_access_token( + config: OAuth2ClientConfig, database_id: int, user_id: int, db_engine_spec: type[BaseEngineSpec], @@ -73,7 +75,7 @@ def get_oauth2_access_token( return token.access_token if token.refresh_token: - return refresh_oauth2_token(database_id, user_id, db_engine_spec, token) + return 
refresh_oauth2_token(config, database_id, user_id, db_engine_spec, token) # since the access token is expired and there's no refresh token, delete the entry db.session.delete(token) @@ -82,6 +84,7 @@ def get_oauth2_access_token( def refresh_oauth2_token( + config: OAuth2ClientConfig, database_id: int, user_id: int, db_engine_spec: type[BaseEngineSpec], @@ -92,7 +95,10 @@ def refresh_oauth2_token( user_id=user_id, database_id=database_id, ): - token_response = db_engine_spec.get_oauth2_fresh_token(token.refresh_token) + token_response = db_engine_spec.get_oauth2_fresh_token( + config, + token.refresh_token, + ) # store new access token; note that the refresh token might be revoked, in which # case there would be no access token in the response diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index eb72878a78..89ad88e972 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -451,7 +451,7 @@ def test_is_oauth2_enabled_no_config(mocker: MockFixture) -> None: mocker.patch( "superset.db_engine_specs.gsheets.current_app.config", - new={"DATABASE_OAUTH2_CREDENTIALS": {}}, + new={"DATABASE_OAUTH2_CLIENTS": {}}, ) assert GSheetsEngineSpec.is_oauth2_enabled() is False @@ -466,7 +466,7 @@ def test_is_oauth2_enabled_config(mocker: MockFixture) -> None: mocker.patch( "superset.db_engine_specs.gsheets.current_app.config", new={ - "DATABASE_OAUTH2_CREDENTIALS": { + "DATABASE_OAUTH2_CLIENTS": { "Google Sheets": { "CLIENT_ID": "XXX.apps.googleusercontent.com", "CLIENT_SECRET": "GOCSPX-YYY", @@ -487,7 +487,7 @@ def test_get_oauth2_authorization_uri(mocker: MockFixture) -> None: mocker.patch( "superset.db_engine_specs.gsheets.current_app.config", new={ - "DATABASE_OAUTH2_CREDENTIALS": { + "DATABASE_OAUTH2_CLIENTS": { "Google Sheets": { "CLIENT_ID": "XXX.apps.googleusercontent.com", "CLIENT_SECRET": "GOCSPX-YYY", @@ -540,7 +540,7 @@ def 
test_get_oauth2_token(mocker: MockFixture) -> None: mocker.patch( "superset.db_engine_specs.gsheets.current_app.config", new={ - "DATABASE_OAUTH2_CREDENTIALS": { + "DATABASE_OAUTH2_CLIENTS": { "Google Sheets": { "CLIENT_ID": "XXX.apps.googleusercontent.com", "CLIENT_SECRET": "GOCSPX-YYY", @@ -598,7 +598,7 @@ def test_get_oauth2_fresh_token(mocker: MockFixture) -> None: mocker.patch( "superset.db_engine_specs.gsheets.current_app.config", new={ - "DATABASE_OAUTH2_CREDENTIALS": { + "DATABASE_OAUTH2_CLIENTS": { "Google Sheets": { "CLIENT_ID": "XXX.apps.googleusercontent.com", "CLIENT_SECRET": "GOCSPX-YYY",
