This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch db-oauth2-client-info
in repository https://gitbox.apache.org/repos/asf/superset.git

commit f31084d002d42d2a9c8765a602ec1f90de5a3a7f
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed May 8 10:18:54 2024 -0400

    WIP
---
 superset/db_engine_specs/snowflake.py | 84 +++++++++++++++++++++++++++++++++++
 1 file changed, 84 insertions(+)

diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index 83d382cda1..a17297217e 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -14,6 +14,9 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from __future__ import annotations
+
 import json
 import logging
 import re
@@ -22,6 +25,7 @@ from re import Pattern
 from typing import Any, Optional, TYPE_CHECKING, TypedDict
 from urllib import parse
 
+import requests
 from apispec import APISpec
 from apispec.ext.marshmallow import MarshmallowPlugin
 from cryptography.hazmat.backends import default_backend
@@ -29,9 +33,11 @@ from cryptography.hazmat.primitives import serialization
 from flask import current_app
 from flask_babel import gettext as __
 from marshmallow import fields, Schema
+from requests.auth import HTTPBasicAuth
 from sqlalchemy import types
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
+from sqlalchemy.exc import ProgrammingError
 
 from superset.constants import TimeGrain, USER_AGENT
 from superset.databases.utils import make_url_safe
@@ -39,6 +45,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
 from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.models.sql_lab import Query
+from superset.superset_typing import OAuth2ClientConfig, OAuth2TokenResponse
 
 if TYPE_CHECKING:
     from superset.models.core import Database
@@ -87,6 +94,11 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     supports_dynamic_schema = True
     supports_catalog = True
 
+    supports_oauth2 = True
+    oauth2_scope = "refresh_token session:role:SYSADMIN"  # TODO: role is hard-coded
+    oauth2_authorization_request_uri = None  # presumably from client config
+    oauth2_token_request_uri = None  # presumably from client config
+
     _time_grain_expressions = {
         None: "{col}",
         TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})",
@@ -391,3 +403,75 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
                     f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
                 )
             connect_args["auth"] = snowflake_auth(**auth_params)
+
+    @classmethod
+    def update_impersonation_config(
+        cls,
+        connect_args: dict[str, Any],
+        uri: str,
+        username: str | None,
+        access_token: str | None,
+    ) -> None:
+        """
+        Switch ``connect_args`` to the ``oauth`` authenticator when a token is set.
+
+        ``uri`` and ``username`` are unused; without ``access_token`` the dict
+        is left unchanged.
+        """
+        if access_token:
+            connect_args.update({"authenticator": "oauth", "token": access_token})
+
+    @classmethod
+    def execute(
+        cls,
+        cursor: Any,
+        query: str,
+        database: Database,
+        **kwargs: Any,
+    ) -> None:
+        """Run ``query``; "User is empty" triggers the OAuth2 dance if enabled."""
+        try:
+            cursor.execute(query)
+        except Exception as ex:
+            # NOTE(review): a DB-API cursor may not raise SQLAlchemy errors; confirm.
+            if isinstance(ex, ProgrammingError) and database.is_oauth2_enabled():
+                if "User is empty" in str(ex):
+                    cls.start_oauth2_dance(database)
+            raise cls.get_dbapi_mapped_exception(ex) from ex
+
+    @classmethod
+    def get_oauth2_token(
+        cls,
+        config: OAuth2ClientConfig,
+        code: str,
+    ) -> OAuth2TokenResponse:
+        """POST the authorization ``code`` to the token URI; return its JSON."""
+        timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
+        response = requests.post(
+            config["token_request_uri"],
+            data={
+                "code": code,
+                "redirect_uri": config["redirect_uri"],
+                "grant_type": "authorization_code",
+            },
+            auth=HTTPBasicAuth(config["id"], config["secret"]),
+            timeout=timeout,
+        )
+        return response.json()
+
+    @classmethod
+    def get_oauth2_fresh_token(
+        cls,
+        config: OAuth2ClientConfig,
+        refresh_token: str,
+    ) -> OAuth2TokenResponse:
+        """POST ``refresh_token`` to the token URI; return its JSON."""
+        timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
+        response = requests.post(
+            config["token_request_uri"],
+            # RFC 6749 section 6: a refresh request uses grant_type=refresh_token.
+            data={"refresh_token": refresh_token, "grant_type": "refresh_token"},
+            auth=HTTPBasicAuth(config["id"], config["secret"]),
+            timeout=timeout,
+        )
+        return response.json()

Reply via email to