This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch sip-85-b
in repository https://gitbox.apache.org/repos/asf/superset.git

commit d08859212cae98c27152e20b7a125eb637aeb4c1
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed Apr 3 10:12:51 2024 -0400

    WIP
---
 superset/config.py                               |  20 ++-
 superset/connectors/sqla/utils.py                |   2 +-
 superset/db_engine_specs/README.md               |  62 ++++----
 superset/db_engine_specs/base.py                 | 174 ++++++++++++++++-------
 superset/db_engine_specs/gsheets.py              |  89 +-----------
 superset/db_engine_specs/hive.py                 |   1 +
 superset/db_engine_specs/presto.py               |   2 +
 superset/db_engine_specs/trino.py                |   2 +
 superset/models/core.py                          |  36 ++++-
 superset/sql_validators/presto_db.py             |   2 +-
 superset/superset_typing.py                      |  49 +++++++
 superset/utils/oauth2.py                         |  12 +-
 tests/unit_tests/db_engine_specs/test_gsheets.py |  10 +-
 13 files changed, 271 insertions(+), 190 deletions(-)

diff --git a/superset/config.py b/superset/config.py
index b6dbfc9fee..bb5b04b232 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1409,12 +1409,20 @@ TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30)
 
 # Details needed for databases that allows user to authenticate using personal
 # OAuth2 tokens. See https://github.com/apache/superset/issues/20300 for more
-# information
-DATABASE_OAUTH2_CREDENTIALS: dict[str, dict[str, Any]] = {
+# information. The scope and URIs are optional.
+DATABASE_OAUTH2_CLIENTS: dict[str, dict[str, Any]] = {
     # "Google Sheets": {
-    #    "CLIENT_ID": "XXX.apps.googleusercontent.com",
-    #    "CLIENT_SECRET": "GOCSPX-YYY",
-    #    "BASEURL": "https://accounts.google.com/o/oauth2/v2/auth";,
+    #     "id": "XXX.apps.googleusercontent.com",
+    #     "secret": "GOCSPX-YYY",
+    #     "scope": " ".join(
+    #         [
+    #             "https://www.googleapis.com/auth/drive.readonly";,
+    #             "https://www.googleapis.com/auth/spreadsheets";,
+    #             "https://spreadsheets.google.com/feeds";,
+    #         ]
+    #     ),
+    #     "authorization_request_uri": 
"https://accounts.google.com/o/oauth2/v2/auth";,
+    #     "token_request_uri": "https://oauth2.googleapis.com/token";,
     # },
 }
 # OAuth2 state is encoded in a JWT using the alogorithm below.
@@ -1425,6 +1433,8 @@ DATABASE_OAUTH2_JWT_ALGORITHM = "HS256"
 # applications. In that case, the proxy can forward the request to the correct 
instance
 # by looking at the `default_redirect_uri` attribute in the OAuth2 state 
object.
 # DATABASE_OAUTH2_REDIRECT_URI = 
"http://localhost:8088/api/v1/database/oauth2/";
+# Timeout when fetching access and refresh tokens.
+DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)
 
 # Enable/disable CSP warning
 CONTENT_SECURITY_POLICY_WARNING = True
diff --git a/superset/connectors/sqla/utils.py 
b/superset/connectors/sqla/utils.py
index d0922e40f3..4bc11aee42 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -145,7 +145,7 @@ def get_columns_description(
             cursor = conn.cursor()
             query = database.apply_limit_to_sql(query, limit=1)
             cursor.execute(query)
-            db_engine_spec.execute(cursor, query, database.id)
+            db_engine_spec.execute(cursor, query, database)
             result = db_engine_spec.fetch_data(cursor, limit=1)
             result_set = SupersetResultSet(result, cursor.description, 
db_engine_spec)
             return result_set.columns
diff --git a/superset/db_engine_specs/README.md 
b/superset/db_engine_specs/README.md
index 0be1f29146..d5d0d6cb4a 100644
--- a/superset/db_engine_specs/README.md
+++ b/superset/db_engine_specs/README.md
@@ -547,65 +547,53 @@ Alternatively, it's also possible to impersonate users by 
implementing the `upda
 
 Support for authenticating to a database using personal OAuth2 access tokens 
was introduced in [SIP-85](https://github.com/apache/superset/issues/20300). 
The Google Sheets DB engine spec is the reference implementation.
 
-To add support for OAuth2 to a DB engine spec, the following attribute and 
methods are needed:
+To add support for OAuth2 to a DB engine spec, the following attributes are 
needed:
 
 ```python
 class BaseEngineSpec:
 
+    supports_oauth2 = True
     oauth2_exception = OAuth2RedirectError
 
-    @staticmethod
-    def is_oauth2_enabled() -> bool:
-        return False
-
-    @staticmethod
-    def get_oauth2_authorization_uri(state: OAuth2State) -> str:
-        raise NotImplementedError()
-
-    @staticmethod
-    def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse:
-        raise NotImplementedError()
-
-    @staticmethod
-    def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse:
-        raise NotImplementedError()
+    oauth2_scope = " ".join([
+        "https://example.org/scope1";,
+        "https://example.org/scope2";,
+    ])
+    oauth2_authorization_request_uri = "https://example.org/authorize";
+    oauth2_token_request_uri = "https://example.org/token";
 ```
 
 The `oauth2_exception` is an exception that is raised by `cursor.execute` when 
OAuth2 is needed. This will start the OAuth2 dance when 
`BaseEngineSpec.execute` is called, by returning the custom error 
`OAUTH2_REDIRECT` to the frontend. If the database driver doesn't have a 
specific exception, it might be necessary to overload the `execute` method in 
the DB engine spec, so that the `BaseEngineSpec.start_oauth2_dance` method gets 
called whenever OAuth2 is needed.
 
-The first method, `is_oauth2_enabled`, is used to inform if the database 
supports OAuth2. This can be dynamic; for example, the Google Sheets DB engine 
spec checks if the Superset configuration has the necessary section:
-
-```python
-from flask import current_app
-
+The DB engine spec should implement logic in either `get_url_for_impersonation` or `update_impersonation_config` to update the connection with the personal access token. See the Google Sheets DB engine spec for a reference implementation.
 
-class GSheetsEngineSpec(ShillelaghEngineSpec):
-    @staticmethod
-    def is_oauth2_enabled() -> bool:
-        return "Google Sheets" in 
current_app.config["DATABASE_OAUTH2_CREDENTIALS"]
-```
-
-Where the configuration for OAuth2 would look like this:
+Currently OAuth2 needs to be configured at the DB engine spec level, i.e., with one client for each DB engine spec. The configuration lives in `superset_config.py`:
 
 ```python
 # superset_config.py
-DATABASE_OAUTH2_CREDENTIALS = {
+DATABASE_OAUTH2_CLIENTS = {
     "Google Sheets": {
-        "CLIENT_ID": "XXX.apps.googleusercontent.com",
-        "CLIENT_SECRET": "GOCSPX-YYY",
+        "id": "XXX.apps.googleusercontent.com",
+        "secret": "GOCSPX-YYY",
+        "scope": " ".join(
+            [
+                "https://www.googleapis.com/auth/drive.readonly";,
+                "https://www.googleapis.com/auth/spreadsheets";,
+                "https://spreadsheets.google.com/feeds";,
+            ],
+        ),
+        "authorization_request_uri": 
"https://accounts.google.com/o/oauth2/v2/auth";,
+        "token_request_uri": "https://oauth2.googleapis.com/token";,
     },
 }
 DATABASE_OAUTH2_JWT_ALGORITHM = "HS256"
 DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/";
+DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)
 ```
 
-The second method, `get_oauth2_authorization_uri`, is responsible for building 
the URL where the user is sent to initiate OAuth2. This method receives a 
`state`. The state is an encoded JWT that is passed to the OAuth2 provider, and 
is received unmodified when the user is redirected back to Superset. The 
default state contains the user ID and the database ID, so that Superset can 
know where to store the received OAuth2 tokens.
-
-Additionally, the state also contains a `tab_id`, which is a random UUID4 used 
as a shared secret for communication between browser tabs. When OAuth2 starts, 
Superset will open a new browser tab, where the user will grant permissions to 
Superset. When authentication is complete and successful this opened tab will 
send a message to the original tab, so that the original query can be re-run. 
The `tab_id` is sent by the opened tab and verified by the original tab to 
prevent malicious messag [...]
-
-State also contains a `defaul_redirect_uri`, which is the enpoint in Supeset 
that receives the tokens from the OAuth2 provider (`/api/v1/database/oauth2/`). 
The redirect URL can be overwritten in the config file via the 
`DATABASE_OAUTH2_REDIRECT_URI` parameter. This might be useful where you have 
multiple Superset instances. Since the OAuth2 provider requires the redirect 
URL to be registered a priori, it might be easier (or needed) to register a 
single URL for a proxy service; the proxy [...]
+When configuring a client, only the ID and secret are required; the DB engine spec should have default values for the scope and endpoints. The `DATABASE_OAUTH2_REDIRECT_URI` setting is optional, and defaults to `/api/v1/database/oauth2/` in Superset.
 
-Finally, `get_oauth2_token` and `get_oauth2_fresh_token` are used to actually 
retrieve a token and refresh an expired token, respectively.
+In the future we plan to support adding custom clients via the Superset UI, 
and being able to manually assign clients to specific databases.
 
 ### File upload
 
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 12797fc6a3..8fdffe2095 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -33,9 +33,11 @@ from typing import (
     TypedDict,
     Union,
 )
+from urllib.parse import urlencode, urljoin
 from uuid import uuid4
 
 import pandas as pd
+import requests
 import sqlparse
 from apispec import APISpec
 from apispec.ext.marshmallow import MarshmallowPlugin
@@ -60,13 +62,20 @@ from superset import security_manager, sql_parse
 from superset.constants import TimeGrain as TimeGrainConstants
 from superset.databases.utils import make_url_safe
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.exceptions import OAuth2Error, OAuth2RedirectError
+from superset.exceptions import OAuth2RedirectError
 from superset.sql_parse import ParsedQuery, SQLScript, Table
-from superset.superset_typing import ResultSetColumnType, SQLAColumnType
+from superset.superset_typing import (
+    OAuth2ClientConfig,
+    OAuth2State,
+    OAuth2TokenResponse,
+    ResultSetColumnType,
+    SQLAColumnType,
+)
 from superset.utils import core as utils
 from superset.utils.core import ColumnSpec, GenericDataType
 from superset.utils.hashing import md5_sha_from_str
 from superset.utils.network import is_hostname_valid, is_port_open
+from superset.utils.oauth2 import encode_oauth2_state
 
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import TableColumn
@@ -173,31 +182,6 @@ class MetricType(TypedDict, total=False):
     extra: str | None
 
 
-class OAuth2TokenResponse(TypedDict, total=False):
-    """
-    Type for an OAuth2 response when exchanging or refreshing tokens.
-    """
-
-    access_token: str
-    expires_in: int
-    scope: str
-    token_type: str
-
-    # only present when exchanging code for refresh/access tokens
-    refresh_token: str
-
-
-class OAuth2State(TypedDict):
-    """
-    Type for the state passed during OAuth2.
-    """
-
-    database_id: int
-    user_id: int
-    default_redirect_uri: str
-    tab_id: str
-
-
 class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     """Abstract class for database engine specific configurations
 
@@ -425,15 +409,25 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
     # Can the catalog be changed on a per-query basis?
     supports_dynamic_catalog = False
 
+    # Does the engine support OAuth 2.0? This requires logic to be added to one of
+    # the user impersonation methods to handle personal tokens.
+    supports_oauth2 = False
+    oauth2_scope = ""
+    oauth2_authorization_request_uri = ""
+    oauth2_token_request_uri = ""
+
     # Driver-specific exception that should be mapped to OAuth2RedirectError
-    oauth2_exception = OAuth2RedirectError
+    oauth2_exception: type[Exception] = OAuth2RedirectError
 
-    @staticmethod
-    def is_oauth2_enabled() -> bool:
-        return False
+    @classmethod
+    def is_oauth2_enabled(cls) -> bool:
+        return (
+            cls.supports_oauth2
+            and cls.engine_name in 
current_app.config["DATABASE_OAUTH2_CLIENTS"]
+        )
 
     @classmethod
-    def start_oauth2_dance(cls, database_id: int) -> None:
+    def start_oauth2_dance(cls, database: "Database") -> None:
         """
         Start the OAuth2 dance.
 
@@ -446,10 +440,6 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         """
         tab_id = str(uuid4())
         default_redirect_uri = url_for("DatabaseRestApi.oauth2", 
_external=True)
-        redirect_uri = current_app.config.get(
-            "DATABASE_OAUTH2_REDIRECT_URI",
-            default_redirect_uri,
-        )
 
         # The state is passed to the OAuth2 provider, and sent back to 
Superset after
         # the user authorizes the access. The redirect endpoint in Superset 
can then
@@ -457,7 +447,7 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         # belongs to.
         state: OAuth2State = {
             # Database ID and user ID are the primary key associated with the 
token.
-            "database_id": database_id,
+            "database_id": database.id,
             "user_id": g.user.id,
             # In multi-instance deployments there might be a single proxy 
handling
             # redirects, with a custom `DATABASE_OAUTH2_REDIRECT_URI`. Since 
the OAuth2
@@ -473,30 +463,112 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
             # message.
             "tab_id": tab_id,
         }
-        oauth_url = cls.get_oauth2_authorization_uri(state)
+        config = database.get_oauth2_config()
+        oauth_url = cls.get_oauth2_authorization_uri(config, state)
 
-        raise OAuth2RedirectError(oauth_url, tab_id, redirect_uri)
+        raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri)
 
-    @staticmethod
-    def get_oauth2_authorization_uri(state: OAuth2State) -> str:
+    @classmethod
+    def get_oauth2_config(cls) -> OAuth2ClientConfig:
+        """
+        Build the OAuth2 client config.
+
+        Currently this is built based on the `config.py` file, since clients 
are defined
+        at the DB engine spec level. In the future we'll allow admins to 
create custom
+        OAuth2 clients and assign them to specific databases.
+        """
+        redirect_uri = current_app.config.get(
+            "DATABASE_OAUTH2_REDIRECT_URI",
+            url_for("DatabaseRestApi.oauth2", _external=True),
+        )
+
+        oauth2_config = current_app.config["DATABASE_OAUTH2_CLIENTS"]
+        try:
+            db_engine_spec_config = oauth2_config[cls.engine_name]
+        except KeyError:
+            raise Exception("TODO")
+
+        config: OAuth2ClientConfig = {
+            "id": db_engine_spec_config["id"],
+            "secret": db_engine_spec_config["secret"],
+            "scope": db_engine_spec_config.get("scope") or cls.oauth2_scope,
+            "redirect_uri": redirect_uri,
+            "authorization_request_uri": db_engine_spec_config.get(
+                "authorization_request_uri"
+            )
+            or cls.oauth2_authorization_request_uri,
+            "token_request_uri": db_engine_spec_config.get("token_request_uri")
+            or cls.oauth2_token_request_uri,
+        }
+
+        return config
+
+    @classmethod
+    def get_oauth2_authorization_uri(
+        cls,
+        config: OAuth2ClientConfig,
+        state: OAuth2State,
+    ) -> str:
         """
         Return URI for initial OAuth2 request.
         """
-        raise OAuth2Error("Subclasses must implement 
`get_oauth2_authorization_uri`")
+        uri = config["authorization_request_uri"]
+        params = {
+            "scope": config["scope"],
+            "access_type": "offline",
+            "include_granted_scopes": "false",
+            "response_type": "code",
+            "state": encode_oauth2_state(state),
+            "redirect_uri": config["redirect_uri"],
+            "client_id": config["id"],
+            "prompt": "consent",
+        }
+        return urljoin(uri, "?" + urlencode(params))
 
-    @staticmethod
-    def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse:
+    @classmethod
+    def get_oauth2_token(
+        cls,
+        config: OAuth2ClientConfig,
+        code: str,
+    ) -> OAuth2TokenResponse:
         """
         Exchange authorization code for refresh/access tokens.
         """
-        raise OAuth2Error("Subclasses must implement `get_oauth2_token`")
+        timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
+        response = requests.post(
+            config["token_request_uri"],
+            data={
+                "code": code,
+                "client_id": config["id"],
+                "client_secret": config["secret"],
+                "redirect_uri": config["redirect_uri"],
+                "grant_type": "authorization_code",
+            },
+            timeout=timeout,
+        )
+        return response.json()
 
     @staticmethod
-    def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse:
+    def get_oauth2_fresh_token(
+        config: OAuth2ClientConfig,
+        refresh_token: str,
+    ) -> OAuth2TokenResponse:
         """
         Refresh an access token that has expired.
         """
-        raise OAuth2Error("Subclasses must implement `get_oauth2_fresh_token`")
+        timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
+        uri = config["token_request_uri"]
+        response = requests.post(
+            uri,
+            data={
+                "client_id": config["id"],
+                "client_secret": config["secret"],
+                "refresh_token": refresh_token,
+                "grant_type": "refresh_token",
+            },
+            timeout=timeout,
+        )
+        return response.json()
 
     @classmethod
     def get_allows_alias_in_select(
@@ -1667,6 +1739,7 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         connect_args: dict[str, Any],
         uri: str,
         username: str | None,
+        access_token: str | None,
     ) -> None:
         """
         Update a configuration dictionary
@@ -1675,6 +1748,7 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         :param connect_args: config to be updated
         :param uri: URI
         :param username: Effective username
+        :param access_token: Personal access token for OAuth2
         :return: None
         """
 
@@ -1683,7 +1757,7 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         cls,
         cursor: Any,
         query: str,
-        database_id: int,
+        database: "Database",
         **kwargs: Any,
     ) -> None:
         """
@@ -1703,8 +1777,8 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         try:
             cursor.execute(query)
         except cls.oauth2_exception as ex:
-            if cls.is_oauth2_enabled() and g.user:
-                cls.start_oauth2_dance(database_id)
+            if database.is_oauth2_enabled() and g.user:
+                cls.start_oauth2_dance(database)
             raise cls.get_dbapi_mapped_exception(ex) from ex
         except Exception as ex:
             raise cls.get_dbapi_mapped_exception(ex) from ex
diff --git a/superset/db_engine_specs/gsheets.py 
b/superset/db_engine_specs/gsheets.py
index 28e95811d4..f74982fe4b 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -23,13 +23,11 @@ import logging
 import re
 from re import Pattern
 from typing import Any, TYPE_CHECKING, TypedDict
-from urllib.parse import urlencode, urljoin
 
 import pandas as pd
-import urllib3
 from apispec import APISpec
 from apispec.ext.marshmallow import MarshmallowPlugin
-from flask import current_app, g
+from flask import g
 from flask_babel import gettext as __
 from marshmallow import fields, Schema
 from marshmallow.exceptions import ValidationError
@@ -42,11 +40,9 @@ from sqlalchemy.engine.url import URL
 from superset import db, security_manager
 from superset.constants import PASSWORD_MASK
 from superset.databases.schemas import encrypted_field_properties, 
EncryptedString
-from superset.db_engine_specs.base import OAuth2State, OAuth2TokenResponse
 from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import SupersetException
-from superset.utils.oauth2 import encode_oauth2_state
 
 if TYPE_CHECKING:
     from superset.models.core import Database
@@ -62,7 +58,6 @@ EXAMPLE_GSHEETS_URL = (
 SYNTAX_ERROR_REGEX = re.compile('SQLError: near "(?P<server_error>.*?)": 
syntax error')
 
 ma_plugin = MarshmallowPlugin()
-http = urllib3.PoolManager()
 
 
 class GSheetsParametersSchema(Schema):
@@ -111,7 +106,11 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
 
     supports_file_upload = True
 
-    # exception raised by shillelagh that should trigger OAuth2
+    # OAuth 2.0
+    supports_oauth2 = True
+    oauth2_scope = " ".join(SCOPES)
+    oauth2_authorization_request_uri = 
"https://accounts.google.com/o/oauth2/v2/auth";
+    oauth2_token_request_uri = "https://oauth2.googleapis.com/token";
     oauth2_exception = UnauthenticatedError
 
     @classmethod
@@ -153,82 +152,6 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
 
         return {"metadata": metadata["extra"]}
 
-    @staticmethod
-    def is_oauth2_enabled() -> bool:
-        """
-        Return if OAuth2 is enabled for GSheets.
-        """
-        return "Google Sheets" in 
current_app.config["DATABASE_OAUTH2_CREDENTIALS"]
-
-    @classmethod
-    def get_oauth2_authorization_uri(cls, state: OAuth2State) -> str:
-        """
-        Return URI for initial OAuth2 request.
-
-        
https://developers.google.com/identity/protocols/oauth2/web-server#creatingclient
-        """
-        config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google 
Sheets"]
-        baseurl = config.get("BASEURL", 
"https://accounts.google.com/o/oauth2/v2/auth";)
-        redirect_uri = current_app.config.get(
-            "DATABASE_OAUTH2_REDIRECT_URI",
-            state["default_redirect_uri"],
-        )
-
-        params = {
-            "scope": " ".join(SCOPES),
-            "access_type": "offline",
-            "include_granted_scopes": "false",
-            "response_type": "code",
-            "state": encode_oauth2_state(state),
-            "redirect_uri": redirect_uri,
-            "client_id": config["CLIENT_ID"],
-            "prompt": "consent",
-        }
-        return urljoin(baseurl, "?" + urlencode(params))
-
-    @staticmethod
-    def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse:
-        """
-        Exchange authorization code for refresh/access tokens.
-        """
-        config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google 
Sheets"]
-        redirect_uri = current_app.config.get(
-            "DATABASE_OAUTH2_REDIRECT_URI",
-            state["default_redirect_uri"],
-        )
-
-        response = http.request(
-            "POST",
-            "https://oauth2.googleapis.com/token";,
-            fields={
-                "code": code,
-                "client_id": config["CLIENT_ID"],
-                "client_secret": config["CLIENT_SECRET"],
-                "redirect_uri": redirect_uri,
-                "grant_type": "authorization_code",
-            },
-        )
-        return json.loads(response.data.decode("utf-8"))
-
-    @staticmethod
-    def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse:
-        """
-        Refresh an access token that has expired.
-        """
-        config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google 
Sheets"]
-
-        response = http.request(
-            "POST",
-            "https://oauth2.googleapis.com/token";,
-            fields={
-                "client_id": config["CLIENT_ID"],
-                "client_secret": config["CLIENT_SECRET"],
-                "refresh_token": refresh_token,
-                "grant_type": "refresh_token",
-            },
-        )
-        return json.loads(response.data.decode("utf-8"))
-
     @classmethod
     def build_sqlalchemy_uri(
         cls,
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index a97dd88aef..4b4c2d951b 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -528,6 +528,7 @@ class HiveEngineSpec(PrestoEngineSpec):
         connect_args: dict[str, Any],
         uri: str,
         username: str | None,
+        access_token: str | None,
     ) -> None:
         """
         Update a configuration dictionary
diff --git a/superset/db_engine_specs/presto.py 
b/superset/db_engine_specs/presto.py
index 0df8d53f4f..ab44edc258 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -717,6 +717,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
         connect_args: dict[str, Any],
         uri: str,
         username: str | None,
+        access_token: str | None,
     ) -> None:
         """
         Update a configuration dictionary
@@ -724,6 +725,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
         :param connect_args: config to be updated
         :param uri: URI string
         :param username: Effective username
+        :param access_token: Personal access token for OAuth2
         :return: None
         """
         url = make_url_safe(uri)
diff --git a/superset/db_engine_specs/trino.py 
b/superset/db_engine_specs/trino.py
index 4513d63c60..4eeac72a7b 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -112,6 +112,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         connect_args: dict[str, Any],
         uri: str,
         username: str | None,
+        access_token: str | None,
     ) -> None:
         """
         Update a configuration dictionary
@@ -119,6 +120,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         :param connect_args: config to be updated
         :param uri: URI string
         :param username: Effective username
+        :param access_token: Personal access token for OAuth2
         :return: None
         """
         url = make_url_safe(uri)
diff --git a/superset/models/core.py b/superset/models/core.py
index cf84b90ac2..eb165e4dfa 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -71,7 +71,7 @@ from superset.extensions import (
 )
 from superset.models.helpers import AuditMixinNullable, ImportExportMixin
 from superset.result_set import SupersetResultSet
-from superset.superset_typing import ResultSetColumnType
+from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
 from superset.utils import cache as cache_util, core as utils
 from superset.utils.backports import StrEnum
 from superset.utils.core import get_username
@@ -467,7 +467,12 @@ class Database(
 
         effective_username = self.get_effective_user(sqlalchemy_url)
         access_token = (
-            get_oauth2_access_token(self.id, g.user.id, self.db_engine_spec)
+            get_oauth2_access_token(
+                self.get_oauth2_config(),
+                self.id,
+                g.user.id,
+                self.db_engine_spec,
+            )
-            if hasattr(g, "user") and hasattr(g.user, "id")
+            if hasattr(getattr(g, "user", None), "id") and self.is_oauth2_enabled()
             else None
         )
@@ -489,6 +494,7 @@ class Database(
                 connect_args,
                 str(sqlalchemy_url),
                 effective_username,
+                access_token,
             )
 
         if connect_args:
@@ -599,7 +605,7 @@ class Database(
                         database=None,
                     )
                 _log_query(sql_)
-                self.db_engine_spec.execute(cursor, sql_, self.id)
+                self.db_engine_spec.execute(cursor, sql_, self)
                 cursor.fetchall()
 
             if mutate_after_split:
@@ -609,10 +615,10 @@ class Database(
                     database=None,
                 )
                 _log_query(last_sql)
-                self.db_engine_spec.execute(cursor, last_sql, self.id)
+                self.db_engine_spec.execute(cursor, last_sql, self)
             else:
                 _log_query(sqls[-1])
-                self.db_engine_spec.execute(cursor, sqls[-1], self.id)
+                self.db_engine_spec.execute(cursor, sqls[-1], self)
 
             data = self.db_engine_spec.fetch_data(cursor)
             result_set = SupersetResultSet(
@@ -983,6 +989,26 @@ class Database(
         sqla_col.key = label_expected
         return sqla_col
 
+    def is_oauth2_enabled(self) -> bool:
+        """
+        Is OAuth2 enabled in the database for authentication?
+
+        Currently this looks for a global config at the DB engine spec level, but in the
+        future we want to allow admins to create custom OAuth2 clients from the
+        Superset UI, and assign them to specific databases.
+        """
+        return self.db_engine_spec.is_oauth2_enabled()
+
+    def get_oauth2_config(self) -> OAuth2ClientConfig:
+        """
+        Return OAuth2 client configuration.
+
+        This includes client ID, client secret, scope, redirect URI, endpoints, etc.
+        Currently this reads the global DB engine spec config, but in the future it
+        should first check if there's a custom client assigned to the database.
+        """
+        return self.db_engine_spec.get_oauth2_config()
+
 
 sqla.event.listen(Database, "after_insert", 
security_manager.database_after_insert)
 sqla.event.listen(Database, "after_update", 
security_manager.database_after_update)
diff --git a/superset/sql_validators/presto_db.py 
b/superset/sql_validators/presto_db.py
index 8c815ad63e..4852f70ee4 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -73,7 +73,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
         from pyhive.exc import DatabaseError
 
         try:
-            db_engine_spec.execute(cursor, sql, database.id)
+            db_engine_spec.execute(cursor, sql, database)
             polled = cursor.poll()
             while polled:
                 logger.info("polling presto for validation progress")
diff --git a/superset/superset_typing.py b/superset/superset_typing.py
index c71dcea3f1..ba623f5819 100644
--- a/superset/superset_typing.py
+++ b/superset/superset_typing.py
@@ -121,3 +121,52 @@ FlaskResponse = Union[
     tuple[Base, Status, Headers],
     tuple[Response, Status],
 ]
+
+
+class OAuth2ClientConfig(TypedDict):
+    """
+    Configuration for an OAuth2 client.
+    """
+
+    # The client ID and secret.
+    id: str
+    secret: str
+
+    # The scopes requested; this is usually a space-separated list of URLs.
+    scope: str
+
+    # The URI where the user is redirected to after authorizing the client; by default
+    # this points to `/api/v1/databases/oauth2/`, but it can be overridden by the admin.
+    redirect_uri: str
+
+    # The URI used to get an authorization code.
+    authorization_request_uri: str
+
+    # The URI used when exchanging the code for an access token, or when refreshing an
+    # expired access token.
+    token_request_uri: str
+
+
+class OAuth2TokenResponse(TypedDict, total=False):
+    """
+    Type for an OAuth2 response when exchanging or refreshing tokens (all keys optional).
+    """
+
+    access_token: str
+    expires_in: int
+    scope: str
+    token_type: str
+
+    # Only present when exchanging a code for refresh/access tokens.
+    refresh_token: str
+
+
+class OAuth2State(TypedDict):
+    """
+    Type for the state passed along during the OAuth2 flow.
+    """
+
+    database_id: int
+    user_id: int
+    default_redirect_uri: str
+    tab_id: str
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index 7e80df9599..9cc58a0b7f 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -26,11 +26,12 @@ from flask import current_app
 from marshmallow import EXCLUDE, fields, post_load, Schema
 
 from superset import db
-from superset.db_engine_specs.base import BaseEngineSpec, OAuth2State
 from superset.exceptions import CreateKeyValueDistributedLockFailedException
+from superset.superset_typing import OAuth2ClientConfig, OAuth2State
 from superset.utils.lock import KeyValueDistributedLock
 
 if TYPE_CHECKING:
+    from superset.db_engine_specs.base import BaseEngineSpec
     from superset.models.core import DatabaseUserOAuth2Tokens
 
 JWT_EXPIRATION = timedelta(minutes=5)
@@ -44,6 +45,7 @@ JWT_EXPIRATION = timedelta(minutes=5)
     max_tries=5,
 )
 def get_oauth2_access_token(
+    config: OAuth2ClientConfig,
     database_id: int,
     user_id: int,
     db_engine_spec: type[BaseEngineSpec],
@@ -73,7 +75,7 @@ def get_oauth2_access_token(
         return token.access_token
 
     if token.refresh_token:
-        return refresh_oauth2_token(database_id, user_id, db_engine_spec, 
token)
+        return refresh_oauth2_token(config, database_id, user_id, 
db_engine_spec, token)
 
     # since the access token is expired and there's no refresh token, delete 
the entry
     db.session.delete(token)
@@ -82,6 +84,7 @@ def get_oauth2_access_token(
 
 
 def refresh_oauth2_token(
+    config: OAuth2ClientConfig,
     database_id: int,
     user_id: int,
     db_engine_spec: type[BaseEngineSpec],
@@ -92,7 +95,10 @@ def refresh_oauth2_token(
         user_id=user_id,
         database_id=database_id,
     ):
-        token_response = 
db_engine_spec.get_oauth2_fresh_token(token.refresh_token)
+        token_response = db_engine_spec.get_oauth2_fresh_token(
+            config,
+            token.refresh_token,
+        )
 
         # store new access token; note that the refresh token might be 
revoked, in which
         # case there would be no access token in the response
diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py 
b/tests/unit_tests/db_engine_specs/test_gsheets.py
index eb72878a78..89ad88e972 100644
--- a/tests/unit_tests/db_engine_specs/test_gsheets.py
+++ b/tests/unit_tests/db_engine_specs/test_gsheets.py
@@ -451,7 +451,7 @@ def test_is_oauth2_enabled_no_config(mocker: MockFixture) 
-> None:
 
     mocker.patch(
         "superset.db_engine_specs.gsheets.current_app.config",
-        new={"DATABASE_OAUTH2_CREDENTIALS": {}},
+        new={"DATABASE_OAUTH2_CLIENTS": {}},
     )
 
     assert GSheetsEngineSpec.is_oauth2_enabled() is False
@@ -466,7 +466,7 @@ def test_is_oauth2_enabled_config(mocker: MockFixture) -> 
None:
     mocker.patch(
         "superset.db_engine_specs.gsheets.current_app.config",
         new={
-            "DATABASE_OAUTH2_CREDENTIALS": {
+            "DATABASE_OAUTH2_CLIENTS": {
                 "Google Sheets": {
                     "CLIENT_ID": "XXX.apps.googleusercontent.com",
                     "CLIENT_SECRET": "GOCSPX-YYY",
@@ -487,7 +487,7 @@ def test_get_oauth2_authorization_uri(mocker: MockFixture) 
-> None:
     mocker.patch(
         "superset.db_engine_specs.gsheets.current_app.config",
         new={
-            "DATABASE_OAUTH2_CREDENTIALS": {
+            "DATABASE_OAUTH2_CLIENTS": {
                 "Google Sheets": {
                     "CLIENT_ID": "XXX.apps.googleusercontent.com",
                     "CLIENT_SECRET": "GOCSPX-YYY",
@@ -540,7 +540,7 @@ def test_get_oauth2_token(mocker: MockFixture) -> None:
     mocker.patch(
         "superset.db_engine_specs.gsheets.current_app.config",
         new={
-            "DATABASE_OAUTH2_CREDENTIALS": {
+            "DATABASE_OAUTH2_CLIENTS": {
                 "Google Sheets": {
                     "CLIENT_ID": "XXX.apps.googleusercontent.com",
                     "CLIENT_SECRET": "GOCSPX-YYY",
@@ -598,7 +598,7 @@ def test_get_oauth2_fresh_token(mocker: MockFixture) -> 
None:
     mocker.patch(
         "superset.db_engine_specs.gsheets.current_app.config",
         new={
-            "DATABASE_OAUTH2_CREDENTIALS": {
+            "DATABASE_OAUTH2_CLIENTS": {
                 "Google Sheets": {
                     "CLIENT_ID": "XXX.apps.googleusercontent.com",
                     "CLIENT_SECRET": "GOCSPX-YYY",


Reply via email to