This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch sip-85-b
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 8d49afd12e233b27ba9486ac0996318dfb21a5c8
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed Apr 3 10:12:51 2024 -0400

    WIP
---
 superset/config.py                               |  20 ++-
 superset/connectors/sqla/utils.py                |   2 +-
 superset/db_engine_specs/README.md               |   8 +-
 superset/db_engine_specs/base.py                 | 174 ++++++++++++++++-------
 superset/db_engine_specs/gsheets.py              |  89 +-----------
 superset/db_engine_specs/hive.py                 |   1 +
 superset/db_engine_specs/presto.py               |   2 +
 superset/db_engine_specs/trino.py                |   2 +
 superset/models/core.py                          |  36 ++++-
 superset/sql_validators/presto_db.py             |   2 +-
 superset/superset_typing.py                      |  49 +++++++
 superset/utils/oauth2.py                         |  12 +-
 tests/unit_tests/db_engine_specs/test_gsheets.py |  10 +-
 13 files changed, 250 insertions(+), 157 deletions(-)

diff --git a/superset/config.py b/superset/config.py
index b6dbfc9fee..bb5b04b232 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -1409,12 +1409,20 @@ TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30)
 
 # Details needed for databases that allows user to authenticate using personal
 # OAuth2 tokens. See https://github.com/apache/superset/issues/20300 for more
-# information
-DATABASE_OAUTH2_CREDENTIALS: dict[str, dict[str, Any]] = {
+# information. The scope and URIs are optional; engine spec defaults are used.
+DATABASE_OAUTH2_CLIENTS: dict[str, dict[str, Any]] = {
     # "Google Sheets": {
-    #    "CLIENT_ID": "XXX.apps.googleusercontent.com",
-    #    "CLIENT_SECRET": "GOCSPX-YYY",
-    #    "BASEURL": "https://accounts.google.com/o/oauth2/v2/auth",
+    #     "id": "XXX.apps.googleusercontent.com",
+    #     "secret": "GOCSPX-YYY",
+    #     "scope": " ".join(
+    #         [
+    #             "https://www.googleapis.com/auth/drive.readonly",
+    #             "https://www.googleapis.com/auth/spreadsheets",
+    #             "https://spreadsheets.google.com/feeds",
+    #         ]
+    #     ),
+    #     "authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth",
+    #     "token_request_uri": "https://oauth2.googleapis.com/token",
     # },
 }
 # OAuth2 state is encoded in a JWT using the alogorithm below.
@@ -1425,6 +1433,8 @@ DATABASE_OAUTH2_JWT_ALGORITHM = "HS256"
 # applications. In that case, the proxy can forward the request to the correct instance
 # by looking at the `default_redirect_uri` attribute in the OAuth2 state object.
 # DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/"
+# Timeout for the HTTP requests that fetch and refresh OAuth2 tokens.
+DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)
 
 # Enable/disable CSP warning
 CONTENT_SECURITY_POLICY_WARNING = True
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index d0922e40f3..4bc11aee42 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -145,7 +145,7 @@ def get_columns_description(
             cursor = conn.cursor()
             query = database.apply_limit_to_sql(query, limit=1)
             cursor.execute(query)
-            db_engine_spec.execute(cursor, query, database.id)
+            db_engine_spec.execute(cursor, query, database)
             result = db_engine_spec.fetch_data(cursor, limit=1)
             result_set = SupersetResultSet(result, cursor.description, 
db_engine_spec)
             return result_set.columns
diff --git a/superset/db_engine_specs/README.md b/superset/db_engine_specs/README.md
index 0be1f29146..175840c608 100644
--- a/superset/db_engine_specs/README.md
+++ b/superset/db_engine_specs/README.md
@@ -582,17 +582,17 @@ from flask import current_app
 class GSheetsEngineSpec(ShillelaghEngineSpec):
     @staticmethod
     def is_oauth2_enabled() -> bool:
-        return "Google Sheets" in 
current_app.config["DATABASE_OAUTH2_CREDENTIALS"]
+        return "Google Sheets" in current_app.config["DATABASE_OAUTH2_CLIENTS"]
 ```
 
 Where the configuration for OAuth2 would look like this:
 
 ```python
 # superset_config.py
-DATABASE_OAUTH2_CREDENTIALS = {
+DATABASE_OAUTH2_CLIENTS = {
     "Google Sheets": {
-        "CLIENT_ID": "XXX.apps.googleusercontent.com",
-        "CLIENT_SECRET": "GOCSPX-YYY",
+        "id": "XXX.apps.googleusercontent.com",
+        "secret": "GOCSPX-YYY",
     },
 }
 DATABASE_OAUTH2_JWT_ALGORITHM = "HS256"
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 12797fc6a3..8fdffe2095 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -33,9 +33,11 @@ from typing import (
     TypedDict,
     Union,
 )
+from urllib.parse import urlencode, urljoin
 from uuid import uuid4
 
 import pandas as pd
+import requests
 import sqlparse
 from apispec import APISpec
 from apispec.ext.marshmallow import MarshmallowPlugin
@@ -60,13 +62,20 @@ from superset import security_manager, sql_parse
 from superset.constants import TimeGrain as TimeGrainConstants
 from superset.databases.utils import make_url_safe
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.exceptions import OAuth2Error, OAuth2RedirectError
+from superset.exceptions import OAuth2RedirectError
 from superset.sql_parse import ParsedQuery, SQLScript, Table
-from superset.superset_typing import ResultSetColumnType, SQLAColumnType
+from superset.superset_typing import (
+    OAuth2ClientConfig,
+    OAuth2State,
+    OAuth2TokenResponse,
+    ResultSetColumnType,
+    SQLAColumnType,
+)
 from superset.utils import core as utils
 from superset.utils.core import ColumnSpec, GenericDataType
 from superset.utils.hashing import md5_sha_from_str
 from superset.utils.network import is_hostname_valid, is_port_open
+from superset.utils.oauth2 import encode_oauth2_state
 
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import TableColumn
@@ -173,31 +182,6 @@ class MetricType(TypedDict, total=False):
     extra: str | None
 
 
-class OAuth2TokenResponse(TypedDict, total=False):
-    """
-    Type for an OAuth2 response when exchanging or refreshing tokens.
-    """
-
-    access_token: str
-    expires_in: int
-    scope: str
-    token_type: str
-
-    # only present when exchanging code for refresh/access tokens
-    refresh_token: str
-
-
-class OAuth2State(TypedDict):
-    """
-    Type for the state passed during OAuth2.
-    """
-
-    database_id: int
-    user_id: int
-    default_redirect_uri: str
-    tab_id: str
-
-
 class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     """Abstract class for database engine specific configurations
 
@@ -425,15 +409,25 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     # Can the catalog be changed on a per-query basis?
     supports_dynamic_catalog = False
 
+    # Does the engine support OAuth 2.0? This requires logic to be added to one of
+    # the user impersonation methods to handle personal tokens.
+    supports_oauth2 = False
+    oauth2_scope = ""  # default scope; a configured client may override it
+    oauth2_authorization_request_uri = ""  # default; overridable per client
+    oauth2_token_request_uri = ""  # default; overridable per client
     # Driver-specific exception that should be mapped to OAuth2RedirectError
-    oauth2_exception = OAuth2RedirectError
+    oauth2_exception: type[Exception] = OAuth2RedirectError
 
-    @staticmethod
-    def is_oauth2_enabled() -> bool:
-        return False
+    @classmethod
+    def is_oauth2_enabled(cls) -> bool:
+        """Return True if the engine supports OAuth2 and a client is configured."""
+        return (
+            cls.supports_oauth2
+            and cls.engine_name in current_app.config["DATABASE_OAUTH2_CLIENTS"]
+        )
 
     @classmethod
-    def start_oauth2_dance(cls, database_id: int) -> None:
+    def start_oauth2_dance(cls, database: "Database") -> None:
         """
         Start the OAuth2 dance.
 
@@ -446,10 +440,6 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         tab_id = str(uuid4())
         default_redirect_uri = url_for("DatabaseRestApi.oauth2", 
_external=True)
-        redirect_uri = current_app.config.get(
-            "DATABASE_OAUTH2_REDIRECT_URI",
-            default_redirect_uri,
-        )
 
         # The state is passed to the OAuth2 provider, and sent back to 
Superset after
         # the user authorizes the access. The redirect endpoint in Superset 
can then
@@ -457,7 +447,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         # belongs to.
         state: OAuth2State = {
             # Database ID and user ID are the primary key associated with the 
token.
-            "database_id": database_id,
+            "database_id": database.id,
             "user_id": g.user.id,
             # In multi-instance deployments there might be a single proxy 
handling
             # redirects, with a custom `DATABASE_OAUTH2_REDIRECT_URI`. Since 
the OAuth2
@@ -473,30 +463,112 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             # message.
             "tab_id": tab_id,
         }
-        oauth_url = cls.get_oauth2_authorization_uri(state)
+        config = database.get_oauth2_config()
+        oauth_url = cls.get_oauth2_authorization_uri(config, state)
 
-        raise OAuth2RedirectError(oauth_url, tab_id, redirect_uri)
+        raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri)
 
-    @staticmethod
-    def get_oauth2_authorization_uri(state: OAuth2State) -> str:
+    @classmethod
+    def get_oauth2_config(cls) -> OAuth2ClientConfig:
+        """
+        Build the OAuth2 client config for this engine.
+
+        Clients come from `DATABASE_OAUTH2_CLIENTS`, keyed by engine name; a missing
+        `scope` or URI falls back to the engine spec default. Raises `ValueError` when
+        no client is configured for the engine.
+        """
+        redirect_uri = current_app.config.get(
+            "DATABASE_OAUTH2_REDIRECT_URI",
+            url_for("DatabaseRestApi.oauth2", _external=True),
+        )
+
+        oauth2_config = current_app.config["DATABASE_OAUTH2_CLIENTS"]
+        try:
+            db_engine_spec_config = oauth2_config[cls.engine_name]
+        except KeyError as ex:
+            raise ValueError(f"No OAuth2 client for {cls.engine_name}") from ex
+
+        config: OAuth2ClientConfig = {
+            "id": db_engine_spec_config["id"],
+            "secret": db_engine_spec_config["secret"],
+            "scope": db_engine_spec_config.get("scope") or cls.oauth2_scope,
+            "redirect_uri": redirect_uri,
+            "authorization_request_uri": db_engine_spec_config.get(
+                "authorization_request_uri"
+            )
+            or cls.oauth2_authorization_request_uri,
+            "token_request_uri": db_engine_spec_config.get("token_request_uri")
+            or cls.oauth2_token_request_uri,
+        }
+
+        return config
+
+    @classmethod
+    def get_oauth2_authorization_uri(
+        cls,
+        config: OAuth2ClientConfig,
+        state: OAuth2State,
+    ) -> str:
         """
         Return URI for initial OAuth2 request.
         """
-        raise OAuth2Error("Subclasses must implement 
`get_oauth2_authorization_uri`")
+        uri = config["authorization_request_uri"]
+        params = {
+            "scope": config["scope"],
+            "access_type": "offline",
+            "include_granted_scopes": "false",
+            "response_type": "code",
+            "state": encode_oauth2_state(state),
+            "redirect_uri": config["redirect_uri"],
+            "client_id": config["id"],
+            "prompt": "consent",
+        }
+        return urljoin(uri, "?" + urlencode(params))
 
-    @staticmethod
-    def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse:
+    @classmethod
+    def get_oauth2_token(
+        config: OAuth2ClientConfig,
+        code: str,
+    ) -> OAuth2TokenResponse:
         """
         Exchange authorization code for refresh/access tokens.
         """
-        raise OAuth2Error("Subclasses must implement `get_oauth2_token`")
+        timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
+        uri = config["token_request_uri"]
+        response = requests.post(
+            uri,
+            fields={
+                "code": code,
+                "client_id": config["id"],
+                "client_secret": config["secret"],
+                "redirect_uri": config["redirect_uri"],
+                "grant_type": "authorization_code",
+            },
+            timeout=timeout,
+        )
+        return response.json()
 
     @staticmethod
-    def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse:
+    def get_oauth2_fresh_token(
+        config: OAuth2ClientConfig,
+        refresh_token: str,
+    ) -> OAuth2TokenResponse:
         """
         Refresh an access token that has expired.
         """
-        raise OAuth2Error("Subclasses must implement `get_oauth2_fresh_token`")
+        timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
+        uri = config["token_request_uri"]
+        response = requests.post(
+            uri,
+            fields={
+                "client_id": config["id"],
+                "client_secret": config["secret"],
+                "refresh_token": refresh_token,
+                "grant_type": "refresh_token",
+            },
+            timeout=timeout,
+        )
+        return response.json()
 
     @classmethod
     def get_allows_alias_in_select(
@@ -1667,6 +1739,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         connect_args: dict[str, Any],
         uri: str,
         username: str | None,
+        access_token: str | None,
     ) -> None:
         """
         Update a configuration dictionary
@@ -1675,6 +1748,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param connect_args: config to be updated
         :param uri: URI
         :param username: Effective username
+        :param access_token: Personal access token for OAuth2
         :return: None
         """
 
@@ -1683,7 +1757,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         cls,
         cursor: Any,
         query: str,
-        database_id: int,
+        database: "Database",
         **kwargs: Any,
     ) -> None:
         """
@@ -1703,8 +1777,8 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         try:
             cursor.execute(query)
         except cls.oauth2_exception as ex:
-            if cls.is_oauth2_enabled() and g.user:
-                cls.start_oauth2_dance(database_id)
+            if database.is_oauth2_enabled() and g.user:
+                cls.start_oauth2_dance(database)
             raise cls.get_dbapi_mapped_exception(ex) from ex
         except Exception as ex:
             raise cls.get_dbapi_mapped_exception(ex) from ex
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index 28e95811d4..f74982fe4b 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -23,13 +23,11 @@ import logging
 import re
 from re import Pattern
 from typing import Any, TYPE_CHECKING, TypedDict
-from urllib.parse import urlencode, urljoin
 
 import pandas as pd
-import urllib3
 from apispec import APISpec
 from apispec.ext.marshmallow import MarshmallowPlugin
-from flask import current_app, g
+from flask import g
 from flask_babel import gettext as __
 from marshmallow import fields, Schema
 from marshmallow.exceptions import ValidationError
@@ -42,11 +40,9 @@ from sqlalchemy.engine.url import URL
 from superset import db, security_manager
 from superset.constants import PASSWORD_MASK
 from superset.databases.schemas import encrypted_field_properties, 
EncryptedString
-from superset.db_engine_specs.base import OAuth2State, OAuth2TokenResponse
 from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import SupersetException
-from superset.utils.oauth2 import encode_oauth2_state
 
 if TYPE_CHECKING:
     from superset.models.core import Database
@@ -62,7 +58,6 @@ EXAMPLE_GSHEETS_URL = (
 SYNTAX_ERROR_REGEX = re.compile('SQLError: near "(?P<server_error>.*?)": 
syntax error')
 
 ma_plugin = MarshmallowPlugin()
-http = urllib3.PoolManager()
 
 
 class GSheetsParametersSchema(Schema):
@@ -111,7 +106,11 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
 
     supports_file_upload = True
 
-    # exception raised by shillelagh that should trigger OAuth2
+    # OAuth 2.0; shillelagh's `UnauthenticatedError` below starts the OAuth2 flow
+    supports_oauth2 = True
+    oauth2_scope = " ".join(SCOPES)
+    oauth2_authorization_request_uri = "https://accounts.google.com/o/oauth2/v2/auth"
+    oauth2_token_request_uri = "https://oauth2.googleapis.com/token"
     oauth2_exception = UnauthenticatedError
 
     @classmethod
@@ -153,82 +152,6 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
 
         return {"metadata": metadata["extra"]}
 
-    @staticmethod
-    def is_oauth2_enabled() -> bool:
-        """
-        Return if OAuth2 is enabled for GSheets.
-        """
-        return "Google Sheets" in 
current_app.config["DATABASE_OAUTH2_CREDENTIALS"]
-
-    @classmethod
-    def get_oauth2_authorization_uri(cls, state: OAuth2State) -> str:
-        """
-        Return URI for initial OAuth2 request.
-
-        
https://developers.google.com/identity/protocols/oauth2/web-server#creatingclient
-        """
-        config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google 
Sheets"]
-        baseurl = config.get("BASEURL", 
"https://accounts.google.com/o/oauth2/v2/auth";)
-        redirect_uri = current_app.config.get(
-            "DATABASE_OAUTH2_REDIRECT_URI",
-            state["default_redirect_uri"],
-        )
-
-        params = {
-            "scope": " ".join(SCOPES),
-            "access_type": "offline",
-            "include_granted_scopes": "false",
-            "response_type": "code",
-            "state": encode_oauth2_state(state),
-            "redirect_uri": redirect_uri,
-            "client_id": config["CLIENT_ID"],
-            "prompt": "consent",
-        }
-        return urljoin(baseurl, "?" + urlencode(params))
-
-    @staticmethod
-    def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse:
-        """
-        Exchange authorization code for refresh/access tokens.
-        """
-        config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google 
Sheets"]
-        redirect_uri = current_app.config.get(
-            "DATABASE_OAUTH2_REDIRECT_URI",
-            state["default_redirect_uri"],
-        )
-
-        response = http.request(
-            "POST",
-            "https://oauth2.googleapis.com/token";,
-            fields={
-                "code": code,
-                "client_id": config["CLIENT_ID"],
-                "client_secret": config["CLIENT_SECRET"],
-                "redirect_uri": redirect_uri,
-                "grant_type": "authorization_code",
-            },
-        )
-        return json.loads(response.data.decode("utf-8"))
-
-    @staticmethod
-    def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse:
-        """
-        Refresh an access token that has expired.
-        """
-        config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google 
Sheets"]
-
-        response = http.request(
-            "POST",
-            "https://oauth2.googleapis.com/token";,
-            fields={
-                "client_id": config["CLIENT_ID"],
-                "client_secret": config["CLIENT_SECRET"],
-                "refresh_token": refresh_token,
-                "grant_type": "refresh_token",
-            },
-        )
-        return json.loads(response.data.decode("utf-8"))
-
     @classmethod
     def build_sqlalchemy_uri(
         cls,
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index a97dd88aef..4b4c2d951b 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -528,6 +528,7 @@ class HiveEngineSpec(PrestoEngineSpec):
         connect_args: dict[str, Any],
         uri: str,
         username: str | None,
+        access_token: str | None,
     ) -> None:
         """
         Update a configuration dictionary
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 0df8d53f4f..ab44edc258 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -717,6 +717,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
         connect_args: dict[str, Any],
         uri: str,
         username: str | None,
+        access_token: str | None,
     ) -> None:
         """
         Update a configuration dictionary
@@ -724,6 +725,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
         :param connect_args: config to be updated
         :param uri: URI string
         :param username: Effective username
+        :param access_token: Personal access token for OAuth2
         :return: None
         """
         url = make_url_safe(uri)
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 4513d63c60..4eeac72a7b 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -112,6 +112,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         connect_args: dict[str, Any],
         uri: str,
         username: str | None,
+        access_token: str | None,
     ) -> None:
         """
         Update a configuration dictionary
@@ -119,6 +120,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         :param connect_args: config to be updated
         :param uri: URI string
         :param username: Effective username
+        :param access_token: Personal access token for OAuth2
         :return: None
         """
         url = make_url_safe(uri)
diff --git a/superset/models/core.py b/superset/models/core.py
index cf84b90ac2..eb165e4dfa 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -71,7 +71,7 @@ from superset.extensions import (
 )
 from superset.models.helpers import AuditMixinNullable, ImportExportMixin
 from superset.result_set import SupersetResultSet
-from superset.superset_typing import ResultSetColumnType
+from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
 from superset.utils import cache as cache_util, core as utils
 from superset.utils.backports import StrEnum
 from superset.utils.core import get_username
@@ -467,7 +467,12 @@ class Database(
 
         effective_username = self.get_effective_user(sqlalchemy_url)
         access_token = (
-            get_oauth2_access_token(self.id, g.user.id, self.db_engine_spec)
-            if hasattr(g, "user") and hasattr(g.user, "id")
+            get_oauth2_access_token(
+                self.get_oauth2_config(),
+                self.id,
+                g.user.id,
+                self.db_engine_spec,
+            )
+            if self.is_oauth2_enabled() and hasattr(g, "user") and hasattr(g.user, "id")
             else None
         )
@@ -489,6 +494,7 @@ class Database(
                 connect_args,
                 str(sqlalchemy_url),
                 effective_username,
+                access_token,
             )
 
         if connect_args:
@@ -599,7 +605,7 @@ class Database(
                         database=None,
                     )
                 _log_query(sql_)
-                self.db_engine_spec.execute(cursor, sql_, self.id)
+                self.db_engine_spec.execute(cursor, sql_, self)
                 cursor.fetchall()
 
             if mutate_after_split:
@@ -609,10 +615,10 @@ class Database(
                     database=None,
                 )
                 _log_query(last_sql)
-                self.db_engine_spec.execute(cursor, last_sql, self.id)
+                self.db_engine_spec.execute(cursor, last_sql, self)
             else:
                 _log_query(sqls[-1])
-                self.db_engine_spec.execute(cursor, sqls[-1], self.id)
+                self.db_engine_spec.execute(cursor, sqls[-1], self)
 
             data = self.db_engine_spec.fetch_data(cursor)
             result_set = SupersetResultSet(
@@ -983,6 +989,26 @@ class Database(
         sqla_col.key = label_expected
         return sqla_col
 
+    def is_oauth2_enabled(self) -> bool:
+        """
+        Is OAuth2 enabled in the database for authentication?
+
+        Currently this looks for a global config at the DB engine spec level, but in the
+        future we want to allow admins to create custom OAuth2 clients from the
+        Superset UI, and assign them to specific databases.
+        """
+        return self.db_engine_spec.is_oauth2_enabled()
+
+    def get_oauth2_config(self) -> OAuth2ClientConfig:
+        """
+        Return OAuth2 client configuration.
+
+        This includes client ID, client secret, scope, redirect URI, endpoints, etc.
+        Currently this reads the global DB engine spec config, which raises if the
+        engine has no client; later it should check for a database-specific client.
+        """
+        return self.db_engine_spec.get_oauth2_config()
+
 
 sqla.event.listen(Database, "after_insert", 
security_manager.database_after_insert)
 sqla.event.listen(Database, "after_update", 
security_manager.database_after_update)
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index 8c815ad63e..4852f70ee4 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -73,7 +73,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
         from pyhive.exc import DatabaseError
 
         try:
-            db_engine_spec.execute(cursor, sql, database.id)
+            db_engine_spec.execute(cursor, sql, database)
             polled = cursor.poll()
             while polled:
                 logger.info("polling presto for validation progress")
diff --git a/superset/superset_typing.py b/superset/superset_typing.py
index c71dcea3f1..ba623f5819 100644
--- a/superset/superset_typing.py
+++ b/superset/superset_typing.py
@@ -121,3 +121,52 @@ FlaskResponse = Union[
     tuple[Base, Status, Headers],
     tuple[Response, Status],
 ]
+
+
+class OAuth2ClientConfig(TypedDict):
+    """
+    Configuration for an OAuth2 client.
+    """
+
+    # The client ID and secret.
+    id: str
+    secret: str
+
+    # The scopes requested; this is usually a space separated list of URLs.
+    scope: str
+
+    # The URI the user is redirected to after authorizing the client; defaults to
+    # `/api/v1/database/oauth2/`, overridable via `DATABASE_OAUTH2_REDIRECT_URI`.
+    redirect_uri: str
+
+    # The URI used to request an authorization code.
+    authorization_request_uri: str
+
+    # The URI used when exchanging the code for an access token, or when refreshing an
+    # expired access token.
+    token_request_uri: str
+
+
+class OAuth2TokenResponse(TypedDict, total=False):
+    """
+    Type for an OAuth2 response when exchanging or refreshing tokens.
+    """
+
+    access_token: str
+    expires_in: int
+    scope: str
+    token_type: str
+
+    # only present when exchanging code for refresh/access tokens
+    refresh_token: str
+
+
+class OAuth2State(TypedDict):
+    """
+    Type for the state passed during OAuth2.
+    """
+
+    database_id: int
+    user_id: int
+    default_redirect_uri: str
+    tab_id: str
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index 7e80df9599..9cc58a0b7f 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -26,11 +26,12 @@ from flask import current_app
 from marshmallow import EXCLUDE, fields, post_load, Schema
 
 from superset import db
-from superset.db_engine_specs.base import BaseEngineSpec, OAuth2State
 from superset.exceptions import CreateKeyValueDistributedLockFailedException
+from superset.superset_typing import OAuth2ClientConfig, OAuth2State
 from superset.utils.lock import KeyValueDistributedLock
 
 if TYPE_CHECKING:
+    from superset.db_engine_specs.base import BaseEngineSpec
     from superset.models.core import DatabaseUserOAuth2Tokens
 
 JWT_EXPIRATION = timedelta(minutes=5)
@@ -44,6 +45,7 @@ JWT_EXPIRATION = timedelta(minutes=5)
     max_tries=5,
 )
 def get_oauth2_access_token(
+    config: OAuth2ClientConfig,
     database_id: int,
     user_id: int,
     db_engine_spec: type[BaseEngineSpec],
@@ -73,7 +75,7 @@ def get_oauth2_access_token(
         return token.access_token
 
     if token.refresh_token:
-        return refresh_oauth2_token(database_id, user_id, db_engine_spec, 
token)
+        return refresh_oauth2_token(config, database_id, user_id, 
db_engine_spec, token)
 
     # since the access token is expired and there's no refresh token, delete 
the entry
     db.session.delete(token)
@@ -82,6 +84,7 @@ def get_oauth2_access_token(
 
 
 def refresh_oauth2_token(
+    config: OAuth2ClientConfig,
     database_id: int,
     user_id: int,
     db_engine_spec: type[BaseEngineSpec],
@@ -92,7 +95,10 @@ def refresh_oauth2_token(
         user_id=user_id,
         database_id=database_id,
     ):
-        token_response = 
db_engine_spec.get_oauth2_fresh_token(token.refresh_token)
+        token_response = db_engine_spec.get_oauth2_fresh_token(
+            config,
+            token.refresh_token,
+        )
 
         # store new access token; note that the refresh token might be 
revoked, in which
         # case there would be no access token in the response
diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py
index eb72878a78..89ad88e972 100644
--- a/tests/unit_tests/db_engine_specs/test_gsheets.py
+++ b/tests/unit_tests/db_engine_specs/test_gsheets.py
@@ -451,7 +451,7 @@ def test_is_oauth2_enabled_no_config(mocker: MockFixture) -> None:
 
     mocker.patch(
         "superset.db_engine_specs.gsheets.current_app.config",
-        new={"DATABASE_OAUTH2_CREDENTIALS": {}},
+        new={"DATABASE_OAUTH2_CLIENTS": {}},
     )
 
     assert GSheetsEngineSpec.is_oauth2_enabled() is False
@@ -466,7 +466,7 @@ def test_is_oauth2_enabled_config(mocker: MockFixture) -> None:
     mocker.patch(
         "superset.db_engine_specs.gsheets.current_app.config",
         new={
-            "DATABASE_OAUTH2_CREDENTIALS": {
+            "DATABASE_OAUTH2_CLIENTS": {
                 "Google Sheets": {
                     "CLIENT_ID": "XXX.apps.googleusercontent.com",
                     "CLIENT_SECRET": "GOCSPX-YYY",
@@ -487,7 +487,7 @@ def test_get_oauth2_authorization_uri(mocker: MockFixture) -> None:
     mocker.patch(
         "superset.db_engine_specs.gsheets.current_app.config",
         new={
-            "DATABASE_OAUTH2_CREDENTIALS": {
+            "DATABASE_OAUTH2_CLIENTS": {
                 "Google Sheets": {
                     "CLIENT_ID": "XXX.apps.googleusercontent.com",
                     "CLIENT_SECRET": "GOCSPX-YYY",
@@ -540,7 +540,7 @@ def test_get_oauth2_token(mocker: MockFixture) -> None:
     mocker.patch(
         "superset.db_engine_specs.gsheets.current_app.config",
         new={
-            "DATABASE_OAUTH2_CREDENTIALS": {
+            "DATABASE_OAUTH2_CLIENTS": {
                 "Google Sheets": {
                     "CLIENT_ID": "XXX.apps.googleusercontent.com",
                     "CLIENT_SECRET": "GOCSPX-YYY",
@@ -598,7 +598,7 @@ def test_get_oauth2_fresh_token(mocker: MockFixture) -> None:
     mocker.patch(
         "superset.db_engine_specs.gsheets.current_app.config",
         new={
-            "DATABASE_OAUTH2_CREDENTIALS": {
+            "DATABASE_OAUTH2_CLIENTS": {
                 "Google Sheets": {
                     "CLIENT_ID": "XXX.apps.googleusercontent.com",
                     "CLIENT_SECRET": "GOCSPX-YYY",

Reply via email to