This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch db-oauth2-client-info in repository https://gitbox.apache.org/repos/asf/superset.git
commit e4998d823a1213c556096ac0f6597fa33be70b6d Author: Beto Dealmeida <[email protected]> AuthorDate: Wed May 8 10:18:54 2024 -0400 WIP --- superset/commands/database/update.py | 4 ++ superset/databases/api.py | 1 - superset/db_engine_specs/base.py | 18 +++-- superset/db_engine_specs/snowflake.py | 128 +++++++++++++++++++++++++++++++--- superset/models/core.py | 14 ++-- superset/sql_lab.py | 9 +++ superset/utils/oauth2.py | 2 +- 7 files changed, 152 insertions(+), 24 deletions(-) diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 5e0968954c..31c1a1e7cf 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -133,6 +133,10 @@ class UpdateDatabaseCommand(BaseCommand): try: schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel) except Exception as ex: + # XXX conditional + if 1: + db.session.commit() + database.db_engine_spec.start_oauth2_dance(database) db.session.rollback() raise DatabaseConnectionFailedError() from ex diff --git a/superset/databases/api.py b/superset/databases/api.py index a77019123b..2fca84a357 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -453,7 +453,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi): @expose("/<int:pk>", methods=("PUT",)) @protect() - @safe @statsd_metrics @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put", diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 3cc1315129..65a4cfbd64 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -131,7 +131,9 @@ builtin_time_grains: dict[str | None, str] = { } -class TimestampExpression(ColumnClause): # pylint: disable=abstract-method, too-many-ancestors +class TimestampExpression( + ColumnClause +): # pylint: disable=abstract-method, too-many-ancestors def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None: """Sqlalchemy class 
that can be used to render native column elements respecting engine-specific quoting rules as part of a string-based expression. @@ -389,9 +391,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods max_column_name_length: int | None = None try_remove_schema_from_table_name = True # pylint: disable=invalid-name run_multiple_statements_as_one = False - custom_errors: dict[ - Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]] - ] = {} + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = ( + {} + ) # Whether the engine supports file uploads # if True, database will be listed as option in the upload file form @@ -1597,9 +1599,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]: return [ - literal_column(query_as) - if (query_as := c.get("query_as")) - else column(c["column_name"]) + ( + literal_column(query_as) + if (query_as := c.get("query_as")) + else column(c["column_name"]) + ) for c in cols ] diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 83d382cda1..baeb4f4762 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+ +from __future__ import annotations + import json import logging import re @@ -22,6 +25,7 @@ from re import Pattern from typing import Any, Optional, TYPE_CHECKING, TypedDict from urllib import parse +import requests from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from cryptography.hazmat.backends import default_backend @@ -29,16 +33,20 @@ from cryptography.hazmat.primitives import serialization from flask import current_app from flask_babel import gettext as __ from marshmallow import fields, Schema +from requests.auth import HTTPBasicAuth from sqlalchemy import types from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL +from sqlalchemy.exc import ProgrammingError +from superset import security_manager from superset.constants import TimeGrain, USER_AGENT from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType from superset.db_engine_specs.postgres import PostgresBaseEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query +from superset.superset_typing import OAuth2ClientConfig, OAuth2TokenResponse if TYPE_CHECKING: from superset.models.core import Database @@ -57,12 +65,12 @@ logger = logging.getLogger(__name__) class SnowflakeParametersSchema(Schema): - username = fields.Str(required=True) - password = fields.Str(required=True) + username = fields.Str(required=False) + password = fields.Str(required=False) account = fields.Str(required=True) database = fields.Str(required=True) - role = fields.Str(required=True) - warehouse = fields.Str(required=True) + role = fields.Str(required=False) + warehouse = fields.Str(required=False) class SnowflakeParametersType(TypedDict): @@ -87,6 +95,11 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): supports_dynamic_schema = True supports_catalog = True + supports_oauth2 = True + oauth2_scope = "refresh_token 
session:role:SYSADMIN" + oauth2_authorization_request_uri = None + oauth2_token_request_uri = None + _time_grain_expressions = { None: "{col}", TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})", @@ -123,17 +136,29 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): ), } - @staticmethod - def get_extra_params(database: "Database") -> dict[str, Any]: + @classmethod + def get_extra_params(cls, database: "Database") -> dict[str, Any]: """ Add a user agent to be used in the requests. """ extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database) engine_params: dict[str, Any] = extra.setdefault("engine_params", {}) connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {}) - connect_args.setdefault("application", USER_AGENT) + # populate OAuth2 URLs if not set, since they can be inferred from the account + if oauth2_client_info := extra.get("oauth2_client_info"): + account = database.url_object.host + oauth2_client_info.setdefault( + "authorization_request_uri", + f"https://{account}.snowflakecomputing.com/oauth/authorize", + ) + oauth2_client_info.setdefault( + "token_request_uri", + f"https://{account}.snowflakecomputing.com/oauth/token-request", + ) + oauth2_client_info.setdefault("scope", cls.oauth2_scope) + return extra @classmethod @@ -303,11 +328,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): ) -> list[SupersetError]: errors: list[SupersetError] = [] required = { - "warehouse", "username", "database", "account", - "role", "password", } parameters = properties.get("parameters", {}) @@ -391,3 +414,90 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config" ) connect_args["auth"] = snowflake_auth(**auth_params) + + @classmethod + def update_impersonation_config( + cls, + connect_args: dict[str, Any], + uri: str, + username: str | None, + access_token: str | None, + ) -> None: + if access_token and not security_manager.is_admin(): + connect_args.update( + { + "authenticator": 
"oauth", + "token": access_token, + }, + ) + + @classmethod + def get_url_for_impersonation( + cls, + url: URL, + impersonate_user: bool, + username: str | None, + access_token: str | None, # pylint: disable=unused-argument + ) -> URL: + # force OAuth2 + if impersonate_user and not security_manager.is_admin(): + url = url._replace(username="", password="", query="") + + return url + + @classmethod + def execute( + cls, + cursor: Any, + query: str, + database: Database, + **kwargs: Any, + ) -> None: + try: + cursor.execute(query) + except ProgrammingError as ex: + # refactor into base class method needs_oauth2 + if database.is_oauth2_enabled() and "User is empty" in str(ex): + cls.start_oauth2_dance(database) + raise cls.get_dbapi_mapped_exception(ex) from ex + except Exception as ex: + raise cls.get_dbapi_mapped_exception(ex) from ex + + @classmethod + def get_oauth2_token( + cls, + config: OAuth2ClientConfig, + code: str, + ) -> OAuth2TokenResponse: + timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() + uri = config["token_request_uri"] + response = requests.post( + uri, + data={ + "code": code, + "redirect_uri": config["redirect_uri"], + "grant_type": "authorization_code", + }, + auth=HTTPBasicAuth(config["id"], config["secret"]), + timeout=timeout, + ) + return response.json() + + @classmethod + def get_oauth2_fresh_token( + cls, + config: OAuth2ClientConfig, + refresh_token: str, + ) -> OAuth2TokenResponse: + timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() + uri = config["token_request_uri"] + response = requests.post( + uri, + data={ + "refresh_token": refresh_token, + "grant_type": "refresh_token", + }, + auth=HTTPBasicAuth(config["id"], config["secret"]), + timeout=timeout, + ) + return response.json() diff --git a/superset/models/core.py b/superset/models/core.py index 79309cdb3d..59c5cf3136 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -117,7 +117,9 @@ class ConfigurationMethod(StrEnum): 
DYNAMIC_FORM = "dynamic_form" -class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods +class Database( + Model, AuditMixinNullable, ImportExportMixin +): # pylint: disable=too-many-public-methods """An ORM object that stores Database related information""" __tablename__ = "dbs" @@ -378,9 +380,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable return ( username if (username := get_username()) - else object_url.username - if self.impersonate_user - else None + else object_url.username if self.impersonate_user else None ) @contextmanager @@ -1027,7 +1027,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable admins to create custom OAuth2 clients from the Superset UI, and assign them to specific databases. """ - oauth2_client_info = self.get_extra().get("oauth2_client_info", {}) + config = json.loads(self.encrypted_extra or "{}") + oauth2_client_info = config.get("oauth2_client_info", {}) return bool(oauth2_client_info) or self.db_engine_spec.is_oauth2_enabled() def get_oauth2_config(self) -> OAuth2ClientConfig | None: @@ -1039,7 +1040,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable admins to create custom OAuth2 clients from the Superset UI, and assign them to specific databases. 
""" - if oauth2_client_info := self.get_extra().get("oauth2_client_info"): + config = json.loads(self.encrypted_extra or "{}") + if oauth2_client_info := config.get("oauth2_client_info"): schema = OAuth2ClientConfigSchema() client_config = schema.load(oauth2_client_info) return cast(OAuth2ClientConfig, client_config) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 3f8c1cc737..95719a76a1 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -186,6 +186,15 @@ def get_sql_results( # pylint: disable=too-many-arguments log_params=log_params, ) except Exception as ex: # pylint: disable=broad-except + query = get_query(query_id) + database = query.database + print("\n\nBETO 456") + print(ex) + try: + database.db_engine_spec.start_oauth2_dance(database) + except OAuth2RedirectError as ex: + return handle_query_error(ex, query) + logger.debug("Query %d: %s", query_id, ex) stats_logger.incr("error_sqllab_unhandled") query = get_query(query_id) diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 7b7440b059..bc4805fd81 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -188,7 +188,7 @@ class OAuth2ClientConfigSchema(Schema): scope = fields.String(required=True) redirect_uri = fields.String( required=False, - load_default=url_for("DatabaseRestApi.oauth2", _external=True), + load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True), ) authorization_request_uri = fields.String(required=True) token_request_uri = fields.String(required=True)
