This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch db-oauth2-client-info
in repository https://gitbox.apache.org/repos/asf/superset.git

commit e4998d823a1213c556096ac0f6597fa33be70b6d
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed May 8 10:18:54 2024 -0400

    WIP
---
 superset/commands/database/update.py  |   4 ++
 superset/databases/api.py             |   1 -
 superset/db_engine_specs/base.py      |  18 +++--
 superset/db_engine_specs/snowflake.py | 128 +++++++++++++++++++++++++++++++---
 superset/models/core.py               |  14 ++--
 superset/sql_lab.py                   |   9 +++
 superset/utils/oauth2.py              |   2 +-
 7 files changed, 152 insertions(+), 24 deletions(-)

diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py
index 5e0968954c..31c1a1e7cf 100644
--- a/superset/commands/database/update.py
+++ b/superset/commands/database/update.py
@@ -133,6 +133,10 @@ class UpdateDatabaseCommand(BaseCommand):
         try:
             schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
         except Exception as ex:
+            # XXX conditional
+            if 1:
+                db.session.commit()
+                database.db_engine_spec.start_oauth2_dance(database)
             db.session.rollback()
             raise DatabaseConnectionFailedError() from ex
 
diff --git a/superset/databases/api.py b/superset/databases/api.py
index a77019123b..2fca84a357 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -453,7 +453,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
 
     @expose("/<int:pk>", methods=("PUT",))
     @protect()
-    @safe
     @statsd_metrics
     @event_logger.log_this_with_context(
         action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.put",
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 3cc1315129..65a4cfbd64 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -131,7 +131,9 @@ builtin_time_grains: dict[str | None, str] = {
 }
 
 
-class TimestampExpression(ColumnClause):  # pylint: disable=abstract-method, too-many-ancestors
+class TimestampExpression(
+    ColumnClause
+):  # pylint: disable=abstract-method, too-many-ancestors
     def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
         """Sqlalchemy class that can be used to render native column elements 
respecting
         engine-specific quoting rules as part of a string-based expression.
@@ -389,9 +391,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     max_column_name_length: int | None = None
     try_remove_schema_from_table_name = True  # pylint: disable=invalid-name
     run_multiple_statements_as_one = False
-    custom_errors: dict[
-        Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
-    ] = {}
+    custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = (
+        {}
+    )
 
     # Whether the engine supports file uploads
     # if True, database will be listed as option in the upload file form
@@ -1597,9 +1599,11 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     @classmethod
     def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
         return [
-            literal_column(query_as)
-            if (query_as := c.get("query_as"))
-            else column(c["column_name"])
+            (
+                literal_column(query_as)
+                if (query_as := c.get("query_as"))
+                else column(c["column_name"])
+            )
             for c in cols
         ]
 
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index 83d382cda1..baeb4f4762 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -14,6 +14,9 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from __future__ import annotations
+
 import json
 import logging
 import re
@@ -22,6 +25,7 @@ from re import Pattern
 from typing import Any, Optional, TYPE_CHECKING, TypedDict
 from urllib import parse
 
+import requests
 from apispec import APISpec
 from apispec.ext.marshmallow import MarshmallowPlugin
 from cryptography.hazmat.backends import default_backend
@@ -29,16 +33,20 @@ from cryptography.hazmat.primitives import serialization
 from flask import current_app
 from flask_babel import gettext as __
 from marshmallow import fields, Schema
+from requests.auth import HTTPBasicAuth
 from sqlalchemy import types
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.url import URL
+from sqlalchemy.exc import ProgrammingError
 
+from superset import security_manager
 from superset.constants import TimeGrain, USER_AGENT
 from superset.databases.utils import make_url_safe
 from superset.db_engine_specs.base import BaseEngineSpec, BasicPropertiesType
 from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.models.sql_lab import Query
+from superset.superset_typing import OAuth2ClientConfig, OAuth2TokenResponse
 
 if TYPE_CHECKING:
     from superset.models.core import Database
@@ -57,12 +65,12 @@ logger = logging.getLogger(__name__)
 
 
 class SnowflakeParametersSchema(Schema):
-    username = fields.Str(required=True)
-    password = fields.Str(required=True)
+    username = fields.Str(required=False)
+    password = fields.Str(required=False)
     account = fields.Str(required=True)
     database = fields.Str(required=True)
-    role = fields.Str(required=True)
-    warehouse = fields.Str(required=True)
+    role = fields.Str(required=False)
+    warehouse = fields.Str(required=False)
 
 
 class SnowflakeParametersType(TypedDict):
@@ -87,6 +95,11 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     supports_dynamic_schema = True
     supports_catalog = True
 
+    supports_oauth2 = True
+    oauth2_scope = "refresh_token session:role:SYSADMIN"
+    oauth2_authorization_request_uri = None
+    oauth2_token_request_uri = None
+
     _time_grain_expressions = {
         None: "{col}",
         TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})",
@@ -123,17 +136,29 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
         ),
     }
 
-    @staticmethod
-    def get_extra_params(database: "Database") -> dict[str, Any]:
+    @classmethod
+    def get_extra_params(cls, database: "Database") -> dict[str, Any]:
         """
         Add a user agent to be used in the requests.
         """
         extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
         engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
         connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
-
         connect_args.setdefault("application", USER_AGENT)
 
+        # populate OAuth2 URLs if not set, since they can be inferred from the account
+        if oauth2_client_info := extra.get("oauth2_client_info"):
+            account = database.url_object.host
+            oauth2_client_info.setdefault(
+                "authorization_request_uri",
+                f"https://{account}.snowflakecomputing.com/oauth/authorize";,
+            )
+            oauth2_client_info.setdefault(
+                "token_request_uri",
+                f"https://{account}.snowflakecomputing.com/oauth/token-request",
+            )
+            oauth2_client_info.setdefault("scope", cls.oauth2_scope)
+
         return extra
 
     @classmethod
@@ -303,11 +328,9 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     ) -> list[SupersetError]:
         errors: list[SupersetError] = []
         required = {
-            "warehouse",
             "username",
             "database",
             "account",
-            "role",
             "password",
         }
         parameters = properties.get("parameters", {})
@@ -391,3 +414,90 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
                     f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
                 )
             connect_args["auth"] = snowflake_auth(**auth_params)
+
+    @classmethod
+    def update_impersonation_config(
+        cls,
+        connect_args: dict[str, Any],
+        uri: str,
+        username: str | None,
+        access_token: str | None,
+    ) -> None:
+        if access_token and not security_manager.is_admin():
+            connect_args.update(
+                {
+                    "authenticator": "oauth",
+                    "token": access_token,
+                },
+            )
+
+    @classmethod
+    def get_url_for_impersonation(
+        cls,
+        url: URL,
+        impersonate_user: bool,
+        username: str | None,
+        access_token: str | None,  # pylint: disable=unused-argument
+    ) -> URL:
+        # force OAuth2
+        if impersonate_user and not security_manager.is_admin():
+            url = url._replace(username="", password="", query="")
+
+        return url
+
+    @classmethod
+    def execute(
+        cls,
+        cursor: Any,
+        query: str,
+        database: Database,
+        **kwargs: Any,
+    ) -> None:
+        try:
+            cursor.execute(query)
+        except ProgrammingError as ex:
+            # refactor into base class method needs_oauth2
+            if database.is_oauth2_enabled() and "User is empty" in str(ex):
+                cls.start_oauth2_dance(database)
+            raise cls.get_dbapi_mapped_exception(ex) from ex
+        except Exception as ex:
+            raise cls.get_dbapi_mapped_exception(ex) from ex
+
+    @classmethod
+    def get_oauth2_token(
+        cls,
+        config: OAuth2ClientConfig,
+        code: str,
+    ) -> OAuth2TokenResponse:
+        timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
+        uri = config["token_request_uri"]
+        response = requests.post(
+            uri,
+            data={
+                "code": code,
+                "redirect_uri": config["redirect_uri"],
+                "grant_type": "authorization_code",
+            },
+            auth=HTTPBasicAuth(config["id"], config["secret"]),
+            timeout=timeout,
+        )
+        return response.json()
+
+    @classmethod
+    def get_oauth2_fresh_token(
+        cls,
+        config: OAuth2ClientConfig,
+        refresh_token: str,
+    ) -> OAuth2TokenResponse:
+        timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
+        uri = config["token_request_uri"]
+        response = requests.post(
+            uri,
+            data={
+                "refresh_token": refresh_token,
+                "grant_type": "refresh_token",
+            },
+            auth=HTTPBasicAuth(config["id"], config["secret"]),
+            timeout=timeout,
+        )
+        return response.json()
diff --git a/superset/models/core.py b/superset/models/core.py
index 79309cdb3d..59c5cf3136 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -117,7 +117,9 @@ class ConfigurationMethod(StrEnum):
     DYNAMIC_FORM = "dynamic_form"
 
 
-class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable=too-many-public-methods
+class Database(
+    Model, AuditMixinNullable, ImportExportMixin
+):  # pylint: disable=too-many-public-methods
     """An ORM object that stores Database related information"""
 
     __tablename__ = "dbs"
@@ -378,9 +380,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         return (
             username
             if (username := get_username())
-            else object_url.username
-            if self.impersonate_user
-            else None
+            else object_url.username if self.impersonate_user else None
         )
 
     @contextmanager
@@ -1027,7 +1027,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         admins to create custom OAuth2 clients from the Superset UI, and assign them to
         specific databases.
         """
-        oauth2_client_info = self.get_extra().get("oauth2_client_info", {})
+        config = json.loads(self.encrypted_extra or "{}")
+        oauth2_client_info = config.get("oauth2_client_info", {})
         return bool(oauth2_client_info) or self.db_engine_spec.is_oauth2_enabled()
 
     def get_oauth2_config(self) -> OAuth2ClientConfig | None:
@@ -1039,7 +1040,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         admins to create custom OAuth2 clients from the Superset UI, and assign them to
         specific databases.
         """
-        if oauth2_client_info := self.get_extra().get("oauth2_client_info"):
+        config = json.loads(self.encrypted_extra or "{}")
+        if oauth2_client_info := config.get("oauth2_client_info"):
             schema = OAuth2ClientConfigSchema()
             client_config = schema.load(oauth2_client_info)
             return cast(OAuth2ClientConfig, client_config)
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 3f8c1cc737..95719a76a1 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -186,6 +186,15 @@ def get_sql_results(  # pylint: disable=too-many-arguments
                 log_params=log_params,
             )
         except Exception as ex:  # pylint: disable=broad-except
+            query = get_query(query_id)
+            database = query.database
+            print("\n\nBETO 456")
+            print(ex)
+            try:
+                database.db_engine_spec.start_oauth2_dance(database)
+            except OAuth2RedirectError as ex:
+                return handle_query_error(ex, query)
+
             logger.debug("Query %d: %s", query_id, ex)
             stats_logger.incr("error_sqllab_unhandled")
             query = get_query(query_id)
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index 7b7440b059..bc4805fd81 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -188,7 +188,7 @@ class OAuth2ClientConfigSchema(Schema):
     scope = fields.String(required=True)
     redirect_uri = fields.String(
         required=False,
-        load_default=url_for("DatabaseRestApi.oauth2", _external=True),
+        load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True),
     )
     authorization_request_uri = fields.String(required=True)
     token_request_uri = fields.String(required=True)

Reply via email to