This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch db-oauth2-client-info in repository https://gitbox.apache.org/repos/asf/superset.git
commit f31084d002d42d2a9c8765a602ec1f90de5a3a7f Author: Beto Dealmeida <[email protected]> AuthorDate: Wed May 8 10:18:54 2024 -0400 WIP --- superset/db_engine_specs/snowflake.py | 84 +++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 83d382cda1..a17297217e 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + import json import logging import re @@ -22,6 +25,7 @@ from re import Pattern from typing import Any, Optional, TYPE_CHECKING, TypedDict from urllib import parse +import requests from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from cryptography.hazmat.backends import default_backend @@ -29,9 +33,11 @@ from cryptography.hazmat.primitives import serialization from flask import current_app from flask_babel import gettext as __ from marshmallow import fields, Schema +from requests.auth import HTTPBasicAuth from sqlalchemy import types from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL +from sqlalchemy.exc import ProgrammingError from superset.constants import TimeGrain, USER_AGENT from superset.databases.utils import make_url_safe @@ -39,6 +45,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType from superset.db_engine_specs.postgres import PostgresBaseEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.models.sql_lab import Query +from superset.superset_typing import OAuth2ClientConfig, OAuth2TokenResponse if TYPE_CHECKING: from superset.models.core import Database @@ -87,6 +94,11 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): supports_dynamic_schema = True supports_catalog = True + supports_oauth2 = True + oauth2_scope = "refresh_token session:role:SYSADMIN" + oauth2_authorization_request_uri = None + oauth2_token_request_uri = None + _time_grain_expressions = { None: "{col}", TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})", @@ -391,3 +403,75 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config" ) connect_args["auth"] = snowflake_auth(**auth_params) + + @classmethod + def update_impersonation_config( + cls, + connect_args: dict[str, Any], + uri: str, + username: str | None, + access_token: str | None, + ) -> None: + if access_token: + connect_args.update( + { + "authenticator": "oauth", + "token": access_token, + }, + ) + + @classmethod + def execute( + cls, + cursor: Any, + query: str, + database: Database, + **kwargs: Any, + ) -> None: + try: + cursor.execute(query) + except ProgrammingError as ex: + if database.is_oauth2_enabled() and "User is empty" in str(ex): + cls.start_oauth2_dance(database) + raise cls.get_dbapi_mapped_exception(ex) from ex + except Exception as ex: + raise cls.get_dbapi_mapped_exception(ex) from ex + + @classmethod + def get_oauth2_token( + cls, + config: OAuth2ClientConfig, + code: str, + ) -> OAuth2TokenResponse: + timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() + uri = config["token_request_uri"] + response = requests.post( + uri, + data={ + "code": code, + "redirect_uri": config["redirect_uri"], + "grant_type": "authorization_code", + }, + auth=HTTPBasicAuth(config["id"], config["secret"]), + timeout=timeout, + ) + return response.json() + + @classmethod + def get_oauth2_fresh_token( + cls, + config: OAuth2ClientConfig, + refresh_token: str, + ) -> OAuth2TokenResponse: + timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() + uri = config["token_request_uri"] + response = requests.post( + uri, + data={ + "refresh_token": refresh_token, + "grant_type": "authorization_code", + }, + auth=HTTPBasicAuth(config["id"], config["secret"]), + timeout=timeout, + ) + return response.json()
