This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch sip-85-b in repository https://gitbox.apache.org/repos/asf/superset.git
commit 8d49afd12e233b27ba9486ac0996318dfb21a5c8 Author: Beto Dealmeida <[email protected]> AuthorDate: Wed Apr 3 10:12:51 2024 -0400 WIP --- superset/config.py | 20 ++- superset/connectors/sqla/utils.py | 2 +- superset/db_engine_specs/README.md | 8 +- superset/db_engine_specs/base.py | 174 ++++++++++++++++------- superset/db_engine_specs/gsheets.py | 89 +----------- superset/db_engine_specs/hive.py | 1 + superset/db_engine_specs/presto.py | 2 + superset/db_engine_specs/trino.py | 2 + superset/models/core.py | 36 ++++- superset/sql_validators/presto_db.py | 2 +- superset/superset_typing.py | 49 +++++++ superset/utils/oauth2.py | 12 +- tests/unit_tests/db_engine_specs/test_gsheets.py | 10 +- 13 files changed, 250 insertions(+), 157 deletions(-) diff --git a/superset/config.py b/superset/config.py index b6dbfc9fee..bb5b04b232 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1409,12 +1409,20 @@ TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30) # Details needed for databases that allows user to authenticate using personal # OAuth2 tokens. See https://github.com/apache/superset/issues/20300 for more -# information -DATABASE_OAUTH2_CREDENTIALS: dict[str, dict[str, Any]] = { +# information. The scope and URIs are optional. +DATABASE_OAUTH2_CLIENTS: dict[str, dict[str, Any]] = { # "Google Sheets": { - # "CLIENT_ID": "XXX.apps.googleusercontent.com", - # "CLIENT_SECRET": "GOCSPX-YYY", - # "BASEURL": "https://accounts.google.com/o/oauth2/v2/auth", + # "id": "XXX.apps.googleusercontent.com", + # "secret": "GOCSPX-YYY", + # "scope": " ".join( + # [ + # "https://www.googleapis.com/auth/drive.readonly", + # "https://www.googleapis.com/auth/spreadsheets", + # "https://spreadsheets.google.com/feeds", + # ] + # ), + # "authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth", + # "token_request_uri": "https://oauth2.googleapis.com/token", # }, } # OAuth2 state is encoded in a JWT using the alogorithm below. 
@@ -1425,6 +1433,8 @@ DATABASE_OAUTH2_JWT_ALGORITHM = "HS256" # applications. In that case, the proxy can forward the request to the correct instance # by looking at the `default_redirect_uri` attribute in the OAuth2 state object. # DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/" +# Timeout when fetching access and refresh tokens. +DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30) # Enable/disable CSP warning CONTENT_SECURITY_POLICY_WARNING = True diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index d0922e40f3..4bc11aee42 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -145,7 +145,7 @@ def get_columns_description( cursor = conn.cursor() query = database.apply_limit_to_sql(query, limit=1) cursor.execute(query) - db_engine_spec.execute(cursor, query, database.id) + db_engine_spec.execute(cursor, query, database) result = db_engine_spec.fetch_data(cursor, limit=1) result_set = SupersetResultSet(result, cursor.description, db_engine_spec) return result_set.columns diff --git a/superset/db_engine_specs/README.md b/superset/db_engine_specs/README.md index 0be1f29146..175840c608 100644 --- a/superset/db_engine_specs/README.md +++ b/superset/db_engine_specs/README.md @@ -582,17 +582,17 @@ from flask import current_app class GSheetsEngineSpec(ShillelaghEngineSpec): @staticmethod def is_oauth2_enabled() -> bool: - return "Google Sheets" in current_app.config["DATABASE_OAUTH2_CREDENTIALS"] + return "Google Sheets" in current_app.config["DATABASE_OAUTH2_CLIENTS"] ``` Where the configuration for OAuth2 would look like this: ```python # superset_config.py -DATABASE_OAUTH2_CREDENTIALS = { +DATABASE_OAUTH2_CLIENTS = { "Google Sheets": { - "CLIENT_ID": "XXX.apps.googleusercontent.com", - "CLIENT_SECRET": "GOCSPX-YYY", + "id": "XXX.apps.googleusercontent.com", + "secret": "GOCSPX-YYY", }, } DATABASE_OAUTH2_JWT_ALGORITHM = "HS256" diff --git a/superset/db_engine_specs/base.py 
b/superset/db_engine_specs/base.py index 12797fc6a3..8fdffe2095 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -33,9 +33,11 @@ from typing import ( TypedDict, Union, ) +from urllib.parse import urlencode, urljoin from uuid import uuid4 import pandas as pd +import requests import sqlparse from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin @@ -60,13 +62,20 @@ from superset import security_manager, sql_parse from superset.constants import TimeGrain as TimeGrainConstants from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import OAuth2Error, OAuth2RedirectError +from superset.exceptions import OAuth2RedirectError from superset.sql_parse import ParsedQuery, SQLScript, Table -from superset.superset_typing import ResultSetColumnType, SQLAColumnType +from superset.superset_typing import ( + OAuth2ClientConfig, + OAuth2State, + OAuth2TokenResponse, + ResultSetColumnType, + SQLAColumnType, +) from superset.utils import core as utils from superset.utils.core import ColumnSpec, GenericDataType from superset.utils.hashing import md5_sha_from_str from superset.utils.network import is_hostname_valid, is_port_open +from superset.utils.oauth2 import encode_oauth2_state if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn @@ -173,31 +182,6 @@ class MetricType(TypedDict, total=False): extra: str | None -class OAuth2TokenResponse(TypedDict, total=False): - """ - Type for an OAuth2 response when exchanging or refreshing tokens. - """ - - access_token: str - expires_in: int - scope: str - token_type: str - - # only present when exchanging code for refresh/access tokens - refresh_token: str - - -class OAuth2State(TypedDict): - """ - Type for the state passed during OAuth2. 
- """ - - database_id: int - user_id: int - default_redirect_uri: str - tab_id: str - - class BaseEngineSpec: # pylint: disable=too-many-public-methods """Abstract class for database engine specific configurations @@ -425,15 +409,25 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # Can the catalog be changed on a per-query basis? supports_dynamic_catalog = False + # Does the engine support OAuth 2.0? This requires logic to be added to one of + # the user impersonation methods to handle personal tokens. + supports_oauth2 = False + oauth2_scope = "" + oauth2_authorization_request_uri = "" + oauth2_token_request_uri = "" + # Driver-specific exception that should be mapped to OAuth2RedirectError - oauth2_exception = OAuth2RedirectError + oauth2_exception: type[Exception] = OAuth2RedirectError - @staticmethod - def is_oauth2_enabled() -> bool: - return False + @classmethod + def is_oauth2_enabled(cls) -> bool: + return ( + cls.supports_oauth2 + and cls.engine_name in current_app.config["DATABASE_OAUTH2_CLIENTS"] + ) @classmethod - def start_oauth2_dance(cls, database_id: int) -> None: + def start_oauth2_dance(cls, database: "Database") -> None: """ Start the OAuth2 dance. @@ -446,10 +440,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ tab_id = str(uuid4()) default_redirect_uri = url_for("DatabaseRestApi.oauth2", _external=True) - redirect_uri = current_app.config.get( - "DATABASE_OAUTH2_REDIRECT_URI", - default_redirect_uri, - ) # The state is passed to the OAuth2 provider, and sent back to Superset after # the user authorizes the access. The redirect endpoint in Superset can then @@ -457,7 +447,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # belongs to. state: OAuth2State = { # Database ID and user ID are the primary key associated with the token.
- "database_id": database_id, + "database_id": database.id, "user_id": g.user.id, # In multi-instance deployments there might be a single proxy handling # redirects, with a custom `DATABASE_OAUTH2_REDIRECT_URI`. Since the OAuth2 @@ -473,30 +463,112 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # message. "tab_id": tab_id, } - oauth_url = cls.get_oauth2_authorization_uri(state) + config = database.get_oauth2_config() + oauth_url = cls.get_oauth2_authorization_uri(config, state) - raise OAuth2RedirectError(oauth_url, tab_id, redirect_uri) + raise OAuth2RedirectError(oauth_url, tab_id, config["redirect_uri"]) - @staticmethod - def get_oauth2_authorization_uri(state: OAuth2State) -> str: + @classmethod + def get_oauth2_config(cls) -> OAuth2ClientConfig: + """ + Build the OAuth2 client config. + + Currently this is built based on the `config.py` file, since clients are defined + at the DB engine spec level. In the future we'll allow admins to create custom + OAuth2 clients and assign them to specific databases. + """ + redirect_uri = current_app.config.get( + "DATABASE_OAUTH2_REDIRECT_URI", + url_for("DatabaseRestApi.oauth2", _external=True), + ) + + oauth2_config = current_app.config["DATABASE_OAUTH2_CLIENTS"] + try: + db_engine_spec_config = oauth2_config[cls.engine_name] + except KeyError as ex: + raise ValueError(f"No OAuth2 client for {cls.engine_name}") from ex + + config: OAuth2ClientConfig = { + "id": db_engine_spec_config["id"], + "secret": db_engine_spec_config["secret"], + "scope": db_engine_spec_config.get("scope") or cls.oauth2_scope, + "redirect_uri": redirect_uri, + "authorization_request_uri": db_engine_spec_config.get( + "authorization_request_uri" + ) + or cls.oauth2_authorization_request_uri, + "token_request_uri": db_engine_spec_config.get("token_request_uri") + or cls.oauth2_token_request_uri, + } + + return config + + @classmethod + def get_oauth2_authorization_uri( + cls, + config: OAuth2ClientConfig, + state: OAuth2State, + ) -> str: """ Return URI for initial OAuth2 request.
""" - raise OAuth2Error("Subclasses must implement `get_oauth2_authorization_uri`") + uri = config["authorization_request_uri"] + params = { + "scope": config["scope"], + "access_type": "offline", + "include_granted_scopes": "false", + "response_type": "code", + "state": encode_oauth2_state(state), + "redirect_uri": config["redirect_uri"], + "client_id": config["id"], + "prompt": "consent", + } + return urljoin(uri, "?" + urlencode(params)) - @staticmethod - def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse: + @staticmethod + def get_oauth2_token( + config: OAuth2ClientConfig, + code: str, + ) -> OAuth2TokenResponse: """ Exchange authorization code for refresh/access tokens. """ - raise OAuth2Error("Subclasses must implement `get_oauth2_token`") + timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() + uri = config["token_request_uri"] + response = requests.post( + uri, + data={ + "code": code, + "client_id": config["id"], + "client_secret": config["secret"], + "redirect_uri": config["redirect_uri"], + "grant_type": "authorization_code", + }, + timeout=timeout, + ) + return response.json() @staticmethod - def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse: + def get_oauth2_fresh_token( + config: OAuth2ClientConfig, + refresh_token: str, + ) -> OAuth2TokenResponse: """ Refresh an access token that has expired.
""" - raise OAuth2Error("Subclasses must implement `get_oauth2_fresh_token`") + timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() + uri = config["token_request_uri"] + response = requests.post( + uri, + data={ + "client_id": config["id"], + "client_secret": config["secret"], + "refresh_token": refresh_token, + "grant_type": "refresh_token", + }, + timeout=timeout, + ) + return response.json() @classmethod def get_allows_alias_in_select( @@ -1667,6 +1739,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods connect_args: dict[str, Any], uri: str, username: str | None, + access_token: str | None, ) -> None: """ Update a configuration dictionary @@ -1675,6 +1748,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param connect_args: config to be updated :param uri: URI :param username: Effective username + :param access_token: Personal access token for OAuth2 :return: None """ @@ -1683,7 +1757,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, cursor: Any, query: str, - database_id: int, + database: "Database", **kwargs: Any, ) -> None: """ @@ -1703,8 +1777,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods try: cursor.execute(query) except cls.oauth2_exception as ex: - if cls.is_oauth2_enabled() and g.user: - cls.start_oauth2_dance(database_id) + if database.is_oauth2_enabled() and g.user: + cls.start_oauth2_dance(database) raise cls.get_dbapi_mapped_exception(ex) from ex except Exception as ex: raise cls.get_dbapi_mapped_exception(ex) from ex diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 28e95811d4..f74982fe4b 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -23,13 +23,11 @@ import logging import re from re import Pattern from typing import Any, TYPE_CHECKING, TypedDict -from urllib.parse import urlencode, urljoin import pandas as pd -import urllib3 from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin -from flask import current_app, g +from flask import g from flask_babel import gettext as __ from marshmallow import fields, Schema from marshmallow.exceptions import ValidationError @@ -42,11 +40,9 @@ from sqlalchemy.engine.url import URL from superset import db, security_manager from superset.constants import PASSWORD_MASK from superset.databases.schemas import encrypted_field_properties, EncryptedString -from superset.db_engine_specs.base import OAuth2State, OAuth2TokenResponse from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException -from superset.utils.oauth2 import encode_oauth2_state if TYPE_CHECKING: from superset.models.core import Database @@ -62,7 +58,6 @@ EXAMPLE_GSHEETS_URL = ( SYNTAX_ERROR_REGEX = re.compile('SQLError: near "(?P<server_error>.*?)": syntax error') ma_plugin = MarshmallowPlugin() -http = urllib3.PoolManager() class GSheetsParametersSchema(Schema): @@ -111,7 +106,11 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): supports_file_upload = True - # exception raised by shillelagh that should trigger OAuth2 + # OAuth 2.0 + supports_oauth2 = True + oauth2_scope = " ".join(SCOPES) + oauth2_authorization_request_uri = "https://accounts.google.com/o/oauth2/v2/auth" + oauth2_token_request_uri = "https://oauth2.googleapis.com/token" oauth2_exception = UnauthenticatedError @classmethod @@ -153,82 +152,6 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): return {"metadata": metadata["extra"]} - @staticmethod - def is_oauth2_enabled() -> bool: - """ - Return if OAuth2 is enabled for GSheets. - """ - return "Google Sheets" in current_app.config["DATABASE_OAUTH2_CREDENTIALS"] - - @classmethod - def get_oauth2_authorization_uri(cls, state: OAuth2State) -> str: - """ - Return URI for initial OAuth2 request. 
- - https://developers.google.com/identity/protocols/oauth2/web-server#creatingclient - """ - config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"] - baseurl = config.get("BASEURL", "https://accounts.google.com/o/oauth2/v2/auth") - redirect_uri = current_app.config.get( - "DATABASE_OAUTH2_REDIRECT_URI", - state["default_redirect_uri"], - ) - - params = { - "scope": " ".join(SCOPES), - "access_type": "offline", - "include_granted_scopes": "false", - "response_type": "code", - "state": encode_oauth2_state(state), - "redirect_uri": redirect_uri, - "client_id": config["CLIENT_ID"], - "prompt": "consent", - } - return urljoin(baseurl, "?" + urlencode(params)) - - @staticmethod - def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse: - """ - Exchange authorization code for refresh/access tokens. - """ - config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"] - redirect_uri = current_app.config.get( - "DATABASE_OAUTH2_REDIRECT_URI", - state["default_redirect_uri"], - ) - - response = http.request( - "POST", - "https://oauth2.googleapis.com/token", - fields={ - "code": code, - "client_id": config["CLIENT_ID"], - "client_secret": config["CLIENT_SECRET"], - "redirect_uri": redirect_uri, - "grant_type": "authorization_code", - }, - ) - return json.loads(response.data.decode("utf-8")) - - @staticmethod - def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse: - """ - Refresh an access token that has expired. 
- """ - config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"] - - response = http.request( - "POST", - "https://oauth2.googleapis.com/token", - fields={ - "client_id": config["CLIENT_ID"], - "client_secret": config["CLIENT_SECRET"], - "refresh_token": refresh_token, - "grant_type": "refresh_token", - }, - ) - return json.loads(response.data.decode("utf-8")) - @classmethod def build_sqlalchemy_uri( cls, diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index a97dd88aef..4b4c2d951b 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -528,6 +528,7 @@ class HiveEngineSpec(PrestoEngineSpec): connect_args: dict[str, Any], uri: str, username: str | None, + access_token: str | None, ) -> None: """ Update a configuration dictionary diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 0df8d53f4f..ab44edc258 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -717,6 +717,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): connect_args: dict[str, Any], uri: str, username: str | None, + access_token: str | None, ) -> None: """ Update a configuration dictionary @@ -724,6 +725,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): :param connect_args: config to be updated :param uri: URI string :param username: Effective username + :param access_token: Personal access token for OAuth2 :return: None """ url = make_url_safe(uri) diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 4513d63c60..4eeac72a7b 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -112,6 +112,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): connect_args: dict[str, Any], uri: str, username: str | None, + access_token: str | None, ) -> None: """ Update a configuration dictionary @@ -119,6 +120,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): :param connect_args: config to be 
updated :param uri: URI string :param username: Effective username + :param access_token: Personal access token for OAuth2 :return: None """ url = make_url_safe(uri) diff --git a/superset/models/core.py b/superset/models/core.py index cf84b90ac2..eb165e4dfa 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -71,7 +71,7 @@ from superset.extensions import ( ) from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.result_set import SupersetResultSet -from superset.superset_typing import ResultSetColumnType +from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType from superset.utils import cache as cache_util, core as utils from superset.utils.backports import StrEnum from superset.utils.core import get_username @@ -467,7 +467,12 @@ class Database( effective_username = self.get_effective_user(sqlalchemy_url) access_token = ( - get_oauth2_access_token(self.id, g.user.id, self.db_engine_spec) + get_oauth2_access_token( + self.get_oauth2_config(), + self.id, + g.user.id, + self.db_engine_spec, + ) - if hasattr(g, "user") and hasattr(g.user, "id") + if self.is_oauth2_enabled() and hasattr(g, "user") and hasattr(g.user, "id") else None ) @@ -489,6 +494,7 @@ class Database( connect_args, str(sqlalchemy_url), effective_username, + access_token, ) if connect_args: @@ -599,7 +605,7 @@ class Database( database=None, ) _log_query(sql_) - self.db_engine_spec.execute(cursor, sql_, self.id) + self.db_engine_spec.execute(cursor, sql_, self) cursor.fetchall() if mutate_after_split: @@ -609,10 +615,10 @@ class Database( database=None, ) _log_query(last_sql) - self.db_engine_spec.execute(cursor, last_sql, self.id) + self.db_engine_spec.execute(cursor, last_sql, self) else: _log_query(sqls[-1]) - self.db_engine_spec.execute(cursor, sqls[-1], self.id) + self.db_engine_spec.execute(cursor, sqls[-1], self) data = self.db_engine_spec.fetch_data(cursor) result_set = SupersetResultSet( @@ -983,6 +989,26 @@ class Database( sqla_col.key = label_expected return sqla_col + def is_oauth2_enabled(self) ->
bool: + """ + Is OAuth2 enabled in the database for authentication? + + Currently this looks for a global config at the DB engine spec level, but in the + future we want to allow admins to create custom OAuth2 clients from the + Superset UI, and assign them to specific databases. + """ + return self.db_engine_spec.is_oauth2_enabled() + + def get_oauth2_config(self) -> OAuth2ClientConfig: + """ + Return OAuth2 client configuration. + + This includes client ID, client secret, scope, redirect URI, endpoints, etc. + Currently this reads the global DB engine spec config, but in the future it + should first check if there's a custom client assigned to the database. + """ + return self.db_engine_spec.get_oauth2_config() + sqla.event.listen(Database, "after_insert", security_manager.database_after_insert) sqla.event.listen(Database, "after_update", security_manager.database_after_update) diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index 8c815ad63e..4852f70ee4 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -73,7 +73,7 @@ class PrestoDBSQLValidator(BaseSQLValidator): from pyhive.exc import DatabaseError try: - db_engine_spec.execute(cursor, sql, database.id) + db_engine_spec.execute(cursor, sql, database) polled = cursor.poll() while polled: logger.info("polling presto for validation progress") diff --git a/superset/superset_typing.py b/superset/superset_typing.py index c71dcea3f1..ba623f5819 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -121,3 +121,52 @@ FlaskResponse = Union[ tuple[Base, Status, Headers], tuple[Response, Status], ] + + +class OAuth2ClientConfig(TypedDict): + """ + Configuration for an OAuth2 client. + """ + + # The client ID and secret. + id: str + secret: str + + # The scopes requested; this is usually a space-separated list of URLs.
+ scope: str + + # The URI where the user is redirected to after authorizing the client; by default + # this points to `/api/v1/database/oauth2/`, but it can be overridden by the admin. + redirect_uri: str + + # The URI used to request an authorization code. + authorization_request_uri: str + + # The URI used when exchanging the code for an access token, or when refreshing an + # expired access token. + token_request_uri: str + + +class OAuth2TokenResponse(TypedDict, total=False): + """ + Type for an OAuth2 response when exchanging or refreshing tokens. + """ + + access_token: str + expires_in: int + scope: str + token_type: str + + # only present when exchanging code for refresh/access tokens + refresh_token: str + + +class OAuth2State(TypedDict): + """ + Type for the state passed during OAuth2. + """ + + database_id: int + user_id: int + default_redirect_uri: str + tab_id: str diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 7e80df9599..9cc58a0b7f 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -26,11 +26,12 @@ from flask import current_app from marshmallow import EXCLUDE, fields, post_load, Schema from superset import db -from superset.db_engine_specs.base import BaseEngineSpec, OAuth2State from superset.exceptions import CreateKeyValueDistributedLockFailedException +from superset.superset_typing import OAuth2ClientConfig, OAuth2State from superset.utils.lock import KeyValueDistributedLock if TYPE_CHECKING: + from superset.db_engine_specs.base import BaseEngineSpec from superset.models.core import DatabaseUserOAuth2Tokens JWT_EXPIRATION = timedelta(minutes=5) @@ -44,6 +45,7 @@ JWT_EXPIRATION = timedelta(minutes=5) max_tries=5, ) def get_oauth2_access_token( + config: OAuth2ClientConfig, database_id: int, user_id: int, db_engine_spec: type[BaseEngineSpec], @@ -73,7 +75,7 @@ def get_oauth2_access_token( return token.access_token if token.refresh_token: - return
refresh_oauth2_token(config, database_id, user_id, db_engine_spec, token) # since the access token is expired and there's no refresh token, delete the entry db.session.delete(token) @@ -82,6 +84,7 @@ def get_oauth2_access_token( def refresh_oauth2_token( + config: OAuth2ClientConfig, database_id: int, user_id: int, db_engine_spec: type[BaseEngineSpec], @@ -92,7 +95,10 @@ def refresh_oauth2_token( user_id=user_id, database_id=database_id, ): - token_response = db_engine_spec.get_oauth2_fresh_token(token.refresh_token) + token_response = db_engine_spec.get_oauth2_fresh_token( + config, + token.refresh_token, + ) # store new access token; note that the refresh token might be revoked, in which # case there would be no access token in the response diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index eb72878a78..89ad88e972 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -451,7 +451,7 @@ def test_is_oauth2_enabled_no_config(mocker: MockFixture) -> None: mocker.patch( "superset.db_engine_specs.gsheets.current_app.config", - new={"DATABASE_OAUTH2_CREDENTIALS": {}}, + new={"DATABASE_OAUTH2_CLIENTS": {}}, ) assert GSheetsEngineSpec.is_oauth2_enabled() is False @@ -466,7 +466,7 @@ def test_is_oauth2_enabled_config(mocker: MockFixture) -> None: mocker.patch( "superset.db_engine_specs.gsheets.current_app.config", new={ - "DATABASE_OAUTH2_CREDENTIALS": { + "DATABASE_OAUTH2_CLIENTS": { "Google Sheets": { "CLIENT_ID": "XXX.apps.googleusercontent.com", "CLIENT_SECRET": "GOCSPX-YYY", @@ -487,7 +487,7 @@ def test_get_oauth2_authorization_uri(mocker: MockFixture) -> None: mocker.patch( "superset.db_engine_specs.gsheets.current_app.config", new={ - "DATABASE_OAUTH2_CREDENTIALS": { + "DATABASE_OAUTH2_CLIENTS": { "Google Sheets": { "CLIENT_ID": "XXX.apps.googleusercontent.com", "CLIENT_SECRET": "GOCSPX-YYY", @@ -540,7 +540,7 @@ def 
test_get_oauth2_token(mocker: MockFixture) -> None: mocker.patch( "superset.db_engine_specs.gsheets.current_app.config", new={ - "DATABASE_OAUTH2_CREDENTIALS": { + "DATABASE_OAUTH2_CLIENTS": { "Google Sheets": { "CLIENT_ID": "XXX.apps.googleusercontent.com", "CLIENT_SECRET": "GOCSPX-YYY", @@ -598,7 +598,7 @@ def test_get_oauth2_fresh_token(mocker: MockFixture) -> None: mocker.patch( "superset.db_engine_specs.gsheets.current_app.config", new={ - "DATABASE_OAUTH2_CREDENTIALS": { + "DATABASE_OAUTH2_CLIENTS": { "Google Sheets": { "CLIENT_ID": "XXX.apps.googleusercontent.com", "CLIENT_SECRET": "GOCSPX-YYY",
