This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new d92af9c95c chore: simplify user impersonation (#32485)
d92af9c95c is described below

commit d92af9c95c08abb1f345360796a1a303563b902d
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu Mar 13 12:43:05 2025 -0400

    chore: simplify user impersonation (#32485)
---
 superset/db_engine_specs/base.py                   | 39 ++++++++--
 superset/db_engine_specs/drill.py                  | 47 ++++++------
 superset/db_engine_specs/gsheets.py                | 20 +++---
 superset/db_engine_specs/hive.py                   | 53 ++++----------
 superset/db_engine_specs/lib.py                    |  1 +
 superset/db_engine_specs/presto.py                 | 37 ++++------
 superset/db_engine_specs/starrocks.py              | 23 +++---
 superset/db_engine_specs/trino.py                  | 55 ++++----------
 superset/models/core.py                            | 41 ++++-------
 tests/unit_tests/db_engine_specs/test_base.py      | 84 +++++++++++++++++++++-
 tests/unit_tests/db_engine_specs/test_drill.py     | 57 +++++++++++----
 tests/unit_tests/db_engine_specs/test_gsheets.py   | 33 +++++----
 tests/unit_tests/db_engine_specs/test_starrocks.py | 26 +++----
 13 files changed, 292 insertions(+), 224 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 084891894b..79e0eb3bfd 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -22,6 +22,7 @@ import logging
 import re
 import warnings
 from datetime import datetime
+from inspect import signature
 from re import Match, Pattern
 from typing import (
     Any,
@@ -1411,11 +1412,6 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         should also have the attribute ``supports_dynamic_schema`` set to 
true, so that
         Superset knows in which schema a given query is running in order to 
enforce
         permissions (see #23385 and #23401).
-
-        Currently, changing the catalog is not supported. The method accepts a 
catalog so
-        that when catalog support is added to Superset the interface remains 
the same.
-        This is important because DB engine specs can be installed from 3rd 
party
-        packages, so we want to keep these methods as stable as possible.
         """  # noqa: E501
         return uri, {
             **connect_args,
@@ -1788,6 +1784,38 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             ]
 
     @classmethod
+    def impersonate_user(
+        cls,
+        database: Database,
+        username: str | None,
+        user_token: str | None,
+        url: URL,
+        engine_kwargs: dict[str, Any],
+    ) -> tuple[URL, dict[str, Any]]:
+        """
+        Modify URL and/or engine kwargs to impersonate a different user.
+        """
+        # Update URL using old methods until 6.0.0.
+        url = cls.get_url_for_impersonation(url, True, username, user_token)
+
+        # Update engine kwargs using old methods. Note that #30674 modified 
the method
+        # signature, so we need to check if the method has the old signature.
+        connect_args = engine_kwargs.setdefault("connect_args", {})
+        args = [
+            connect_args,
+            url,
+            username,
+            user_token,
+        ]
+        if "database" in signature(cls.update_impersonation_config).parameters:
+            args.insert(0, database)
+
+        cls.update_impersonation_config(*args)
+
+        return url, engine_kwargs
+
+    @classmethod
+    @deprecated(deprecated_in="6.0.0")
     def get_url_for_impersonation(
         cls,
         url: URL,
@@ -1809,6 +1837,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         return url
 
     @classmethod
+    @deprecated(deprecated_in="6.0.0")
     def update_impersonation_config(  # pylint: disable=too-many-arguments
         cls,
         database: Database,
diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py
index e99d4a27f4..167ed386ea 100644
--- a/superset/db_engine_specs/drill.py
+++ b/superset/db_engine_specs/drill.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from datetime import datetime
-from typing import Any
+from typing import Any, TYPE_CHECKING
 from urllib import parse
 
 from sqlalchemy import types
@@ -28,6 +28,9 @@ from superset.constants import TimeGrain
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.exceptions import SupersetDBAPIProgrammingError
 
+if TYPE_CHECKING:
+    from superset.models.core import Database
+
 
 class DrillEngineSpec(BaseEngineSpec):
     """Engine spec for Apache Drill"""
@@ -99,31 +102,27 @@ class DrillEngineSpec(BaseEngineSpec):
         return parse.unquote(sqlalchemy_uri.database).replace("/", ".")
 
     @classmethod
-    def get_url_for_impersonation(
+    def impersonate_user(
         cls,
-        url: URL,
-        impersonate_user: bool,
+        database: Database,
         username: str | None,
-        access_token: str | None,
-    ) -> URL:
-        """
-        Return a modified URL with the username set.
-
-        :param url: SQLAlchemy URL object
-        :param impersonate_user: Flag indicating if impersonation is enabled
-        :param username: Effective username
-        """
-        if impersonate_user and username is not None:
-            if url.drivername == "drill+odbc":
-                url = url.update_query_dict({"DelegationUID": username})
-            elif url.drivername in ["drill+sadrill", "drill+jdbc"]:
-                url = url.update_query_dict({"impersonation_target": username})
-            else:
-                raise SupersetDBAPIProgrammingError(
-                    f"impersonation is not supported for {url.drivername}"
-                )
-
-        return url
+        user_token: str | None,
+        url: URL,
+        engine_kwargs: dict[str, Any],
+    ) -> tuple[URL, dict[str, Any]]:
+        if username is None:
+            return url, engine_kwargs
+
+        if url.drivername == "drill+odbc":
+            url = url.update_query_dict({"DelegationUID": username})
+        elif url.drivername in {"drill+sadrill", "drill+jdbc"}:
+            url = url.update_query_dict({"impersonation_target": username})
+        else:
+            raise SupersetDBAPIProgrammingError(
+                f"impersonation is not supported for {url.drivername}"
+            )
+
+        return url, engine_kwargs
 
     @classmethod
     def fetch_data(
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index 883a54abb2..647d0bc2fb 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -130,25 +130,23 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
     oauth2_exception = UnauthenticatedError
 
     @classmethod
-    def get_url_for_impersonation(
+    def impersonate_user(
         cls,
-        url: URL,
-        impersonate_user: bool,
+        database: Database,
         username: str | None,
-        access_token: str | None,
-    ) -> URL:
-        if not impersonate_user:
-            return url
-
+        user_token: str | None,
+        url: URL,
+        engine_kwargs: dict[str, Any],
+    ) -> tuple[URL, dict[str, Any]]:
         if username is not None:
             user = security_manager.find_user(username=username)
             if user and user.email:
                 url = url.update_query_dict({"subject": user.email})
 
-        if access_token:
-            url = url.update_query_dict({"access_token": access_token})
+        if user_token:
+            url = url.update_query_dict({"access_token": user_token})
 
-        return url
+        return url, engine_kwargs
 
     @classmethod
     def get_extra_table_metadata(
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 0a9b817804..0d2bdd3a5d 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -39,7 +39,6 @@ from sqlalchemy.sql.expression import ColumnClause, Select
 from superset import db
 from superset.common.db_query_status import QueryStatus
 from superset.constants import TimeGrain
-from superset.databases.utils import make_url_safe
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.db_engine_specs.presto import PrestoEngineSpec
 from superset.exceptions import SupersetException
@@ -510,54 +509,26 @@ class HiveEngineSpec(PrestoEngineSpec):
         )
 
     @classmethod
-    def get_url_for_impersonation(
-        cls,
-        url: URL,
-        impersonate_user: bool,
-        username: str | None,
-        access_token: str | None,
-    ) -> URL:
-        """
-        Return a modified URL with the username set.
-
-        :param url: SQLAlchemy URL object
-        :param impersonate_user: Flag indicating if impersonation is enabled
-        :param username: Effective username
-        """
-        # Do nothing in the URL object since instead this should modify
-        # the configuration dictionary. See get_configuration_for_impersonation
-        return url
-
-    @classmethod
-    def update_impersonation_config(  # pylint: disable=too-many-arguments
+    def impersonate_user(
         cls,
         database: Database,
-        connect_args: dict[str, Any],
-        uri: str,
         username: str | None,
-        access_token: str | None,
-    ) -> None:
-        """
-        Update a configuration dictionary
-        that can set the correct properties for impersonating users
-        :param database: the Database Object
-        :param connect_args:
-        :param uri: URI string
-        :param impersonate_user: Flag indicating if impersonation is enabled
-        :param username: Effective username
-        :return: None
-        """
-        url = make_url_safe(uri)
-        backend_name = url.get_backend_name()
+        user_token: str | None,
+        url: URL,
+        engine_kwargs: dict[str, Any],
+    ) -> tuple[URL, dict[str, Any]]:
+        if username is None:
+            return url, engine_kwargs
 
-        # Must be Hive connection, enable impersonation, and set optional param
-        # auth=LDAP|KERBEROS
-        # this will set hive.server2.proxy.user=$effective_username on 
connect_args['configuration']  # noqa: E501
-        if backend_name == "hive" and username is not None:
+        backend_name = url.get_backend_name()
+        connect_args = engine_kwargs.setdefault("connect_args", {})
+        if backend_name == "hive":
             configuration = connect_args.get("configuration", {})
             configuration["hive.server2.proxy.user"] = username
             connect_args["configuration"] = configuration
 
+        return url, engine_kwargs
+
     @staticmethod
     def execute(  # type: ignore
         cursor,
diff --git a/superset/db_engine_specs/lib.py b/superset/db_engine_specs/lib.py
index fc7a8168b2..106b9c7550 100644
--- a/superset/db_engine_specs/lib.py
+++ b/superset/db_engine_specs/lib.py
@@ -140,6 +140,7 @@ def diagnose(spec: type[BaseEngineSpec]) -> dict[str, Any]:
             "user_impersonation": (
                 has_custom_method(spec, "update_impersonation_config")
                 or has_custom_method(spec, "get_url_for_impersonation")
+                or has_custom_method(spec, "impersonate_user")
             ),
             "file_upload": spec.supports_file_upload,
             "get_extra_table_metadata": has_custom_method(
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index c3c8f61829..fa1bad8c2c 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -43,7 +43,6 @@ from sqlalchemy.sql.expression import ColumnClause, Select
 from superset import cache_manager, db, is_feature_enabled
 from superset.common.db_query_status import QueryStatus
 from superset.constants import TimeGrain
-from superset.databases.utils import make_url_safe
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.errors import SupersetErrorType
 from superset.exceptions import SupersetTemplateException
@@ -954,34 +953,26 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
         return version is not None and Version(version) >= Version("0.319")
 
     @classmethod
-    def update_impersonation_config(  # pylint: disable=too-many-arguments
+    def impersonate_user(
         cls,
         database: Database,
-        connect_args: dict[str, Any],
-        uri: str,
         username: str | None,
-        access_token: str | None,
-    ) -> None:
-        """
-        Update a configuration dictionary
-        that can set the correct properties for impersonating users
-
-        :param connect_args: the Database object
-        :param connect_args: config to be updated
-        :param uri: URI string
-        :param username: Effective username
-        :param access_token: Personal access token for OAuth2
-        :return: None
-        """
-        url = make_url_safe(uri)
-        backend_name = url.get_backend_name()
+        user_token: str | None,
+        url: URL,
+        engine_kwargs: dict[str, Any],
+    ) -> tuple[URL, dict[str, Any]]:
+        if username is None:
+            return url, engine_kwargs
+
+        url = url.set(username=username)
 
-        # Must be Presto connection, enable impersonation, and set optional 
param
-        # auth=LDAP|KERBEROS
-        # Set principal_username=$effective_username
-        if backend_name == "presto" and username is not None:
+        backend_name = url.get_backend_name()
+        connect_args = engine_kwargs.setdefault("connect_args", {})
+        if backend_name == "presto":
             connect_args["principal_username"] = username
 
+        return url, engine_kwargs
+
     @classmethod
     def get_table_names(
         cls,
diff --git a/superset/db_engine_specs/starrocks.py b/superset/db_engine_specs/starrocks.py
index 6f54329d6f..d3e2172f2b 100644
--- a/superset/db_engine_specs/starrocks.py
+++ b/superset/db_engine_specs/starrocks.py
@@ -204,23 +204,22 @@ class StarRocksEngineSpec(MySQLEngineSpec):
         return parse.unquote(database.split(".")[1])
 
     @classmethod
-    def get_url_for_impersonation(
+    def impersonate_user(
         cls,
+        database: Database,
+        username: str | None,
+        user_token: str | None,
         url: URL,
-        impersonate_user: bool,
-        username: Union[str, None] = None,
-        access_token: Union[str, None] = None,
-    ) -> URL:
+        engine_kwargs: dict[str, Any],
+    ) -> tuple[URL, dict[str, Any]]:
         """
-        Return a modified URL with the username set.
+        Impersonate the given user.
 
-        :param url: SQLAlchemy URL object
-        :param impersonate_user: Flag indicating if impersonation is enabled
-        :param username: Effective username
-        :param access_token: Personal access token
+        User impersonation is actually achieved via `get_prequeries`, so this 
method
+        needs to ensure that the username is not added to the URL when user
+        impersonation is enabled (the behavior of the base class).
         """
-        # Leave URL unchanged. We will impersonate with the pre-query below.
-        return url
+        return url, engine_kwargs
 
     @classmethod
     def get_prequeries(
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 79fdef19bf..e232e08978 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -30,7 +30,6 @@ from sqlalchemy.exc import NoSuchTableError
 
 from superset import db
 from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
-from superset.databases.utils import make_url_safe
 from superset.db_engine_specs.base import BaseEngineSpec, 
convert_inspector_columns
 from superset.db_engine_specs.exceptions import (
     SupersetDBAPIConnectionError,
@@ -131,55 +130,27 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
         return metadata
 
     @classmethod
-    def update_impersonation_config(  # pylint: disable=too-many-arguments
+    def impersonate_user(
         cls,
         database: Database,
-        connect_args: dict[str, Any],
-        uri: str,
         username: str | None,
-        access_token: str | None,
-    ) -> None:
-        """
-        Update a configuration dictionary
-        that can set the correct properties for impersonating users
-        :param database: the Database object
-        :param connect_args: config to be updated
-        :param uri: URI string
-        :param username: Effective username
-        :param access_token: Personal access token for OAuth2
-        :return: None
-        """
-        url = make_url_safe(uri)
-        backend_name = url.get_backend_name()
+        user_token: str | None,
+        url: URL,
+        engine_kwargs: dict[str, Any],
+    ) -> tuple[URL, dict[str, Any]]:
+        if username is None:
+            return url, engine_kwargs
 
-        # Must be Trino connection, enable impersonation, and set optional 
param
-        # auth=LDAP|KERBEROS
-        # Set principal_username=$effective_username
-        if backend_name == "trino" and username is not None:
+        backend_name = url.get_backend_name()
+        connect_args = engine_kwargs.setdefault("connect_args", {})
+        if backend_name == "trino":
             connect_args["user"] = username
-            if access_token is not None:
+            if user_token is not None:
                 http_session = requests.Session()
-                http_session.headers.update({"Authorization": f"Bearer 
{access_token}"})
+                http_session.headers.update({"Authorization": f"Bearer 
{user_token}"})
                 connect_args["http_session"] = http_session
 
-    @classmethod
-    def get_url_for_impersonation(
-        cls,
-        url: URL,
-        impersonate_user: bool,
-        username: str | None,
-        access_token: str | None,
-    ) -> URL:
-        """
-        Return a modified URL with the username set.
-
-        :param access_token: Personal access token for OAuth2
-        :param url: SQLAlchemy URL object
-        :param impersonate_user: Flag indicating if impersonation is enabled
-        :param username: Effective username
-        """
-        # Do nothing and let update_impersonation_config take care of 
impersonation
-        return url
+        return url, engine_kwargs
 
     @classmethod
     def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
diff --git a/superset/models/core.py b/superset/models/core.py
index 28f454295f..9378452bd8 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -479,11 +479,12 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         self.db_engine_spec.validate_database_uri(sqlalchemy_url)
 
         extra = self.get_extra(source)
-        params = extra.get("engine_params", {})
+        engine_kwargs = extra.get("engine_params", {})
         if nullpool:
-            params["poolclass"] = NullPool
-        connect_args = params.get("connect_args", {})
+            engine_kwargs["poolclass"] = NullPool
+        connect_args = engine_kwargs.setdefault("connect_args", {})
 
+        # modify URL/args for a specific catalog/schema
         sqlalchemy_url, connect_args = 
self.db_engine_spec.adjust_engine_params(
             uri=sqlalchemy_url,
             connect_args=connect_args,
@@ -508,46 +509,32 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
             if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
             else None
         )
-        # If using MySQL or Presto for example, will set url.username
-        # If using Hive, will not do anything yet since that relies on a
-        # configuration parameter instead.
-        sqlalchemy_url = self.db_engine_spec.get_url_for_impersonation(
-            sqlalchemy_url,
-            self.impersonate_user,
-            effective_username,
-            access_token,
-        )
-
         masked_url = self.get_password_masked_url(sqlalchemy_url)
         logger.debug("Database._get_sqla_engine(). Masked URL: %s", 
str(masked_url))
 
         if self.impersonate_user:
-            # PR #30674 changed the signature of the method to include 
database.
-            # This ensures that the change is backwards compatible
-            args = [connect_args, str(sqlalchemy_url), effective_username, 
access_token]
-            args = self.add_database_to_signature(
-                self.db_engine_spec.update_impersonation_config,
-                args,
+            sqlalchemy_url, engine_kwargs = 
self.db_engine_spec.impersonate_user(
+                self,
+                effective_username,
+                access_token,
+                sqlalchemy_url,
+                engine_kwargs,
             )
-            self.db_engine_spec.update_impersonation_config(*args)
-
-        if connect_args:
-            params["connect_args"] = connect_args
 
-        self.update_params_from_encrypted_extra(params)
+        self.update_params_from_encrypted_extra(engine_kwargs)
 
         if DB_CONNECTION_MUTATOR:
             source = source or get_query_source_from_request()
 
-            sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
+            sqlalchemy_url, engine_kwargs = DB_CONNECTION_MUTATOR(
                 sqlalchemy_url,
-                params,
+                engine_kwargs,
                 effective_username,
                 security_manager,
                 source,
             )
         try:
-            return create_engine(sqlalchemy_url, **params)
+            return create_engine(sqlalchemy_url, **engine_kwargs)
         except Exception as ex:
             raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
 
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index bbc3bb0edc..c100007d38 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -27,7 +27,7 @@ import pytest
 from pytest_mock import MockerFixture
 from sqlalchemy import types
 from sqlalchemy.dialects import sqlite
-from sqlalchemy.engine.url import URL
+from sqlalchemy.engine.url import make_url, URL
 from sqlalchemy.sql import sqltypes
 
 from superset.sql_parse import Table
@@ -382,3 +382,85 @@ def test_unmask_encrypted_extra() -> None:
             },
         }
     )
+
+
+def test_impersonate_user_backwards_compatible(mocker: MockerFixture) -> None:
+    """
+    Test that the `impersonate_user` method calls the original methods it 
replaced.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+
+    database = mocker.MagicMock()
+    url = make_url("sqlite://foo.db")
+    new_url = make_url("sqlite://bar.db")
+    engine_kwargs = {"connect_args": {"user": "alice"}}
+
+    get_url_for_impersonation = mocker.patch.object(
+        BaseEngineSpec,
+        "get_url_for_impersonation",
+        return_value=new_url,
+    )
+    update_impersonation_config = mocker.patch.object(
+        BaseEngineSpec,
+        "update_impersonation_config",
+    )
+    signature = mocker.patch("superset.db_engine_specs.base.signature")
+    signature().parameters = [
+        "cls",
+        "database",
+        "connect_args",
+        "uri",
+        "username",
+        "access_token",
+    ]
+
+    BaseEngineSpec.impersonate_user(database, "alice", "SECRET", url, 
engine_kwargs)
+
+    get_url_for_impersonation.assert_called_once_with(url, True, "alice", 
"SECRET")
+    update_impersonation_config.assert_called_once_with(
+        database,
+        {"user": "alice"},
+        new_url,
+        "alice",
+        "SECRET",
+    )
+
+
+def test_impersonate_user_no_database(mocker: MockerFixture) -> None:
+    """
+    Test `impersonate_user` when `update_impersonation_config` has an old 
signature.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+
+    database = mocker.MagicMock()
+    url = make_url("sqlite://foo.db")
+    new_url = make_url("sqlite://bar.db")
+    engine_kwargs = {"connect_args": {"user": "alice"}}
+
+    get_url_for_impersonation = mocker.patch.object(
+        BaseEngineSpec,
+        "get_url_for_impersonation",
+        return_value=new_url,
+    )
+    update_impersonation_config = mocker.patch.object(
+        BaseEngineSpec,
+        "update_impersonation_config",
+    )
+    signature = mocker.patch("superset.db_engine_specs.base.signature")
+    signature().parameters = [
+        "cls",
+        "connect_args",
+        "uri",
+        "username",
+        "access_token",
+    ]
+
+    BaseEngineSpec.impersonate_user(database, "alice", "SECRET", url, 
engine_kwargs)
+
+    get_url_for_impersonation.assert_called_once_with(url, True, "alice", 
"SECRET")
+    update_impersonation_config.assert_called_once_with(
+        {"user": "alice"},
+        new_url,
+        "alice",
+        "SECRET",
+    )
diff --git a/tests/unit_tests/db_engine_specs/test_drill.py b/tests/unit_tests/db_engine_specs/test_drill.py
index eb3414ea5b..366190eaaa 100644
--- a/tests/unit_tests/db_engine_specs/test_drill.py
+++ b/tests/unit_tests/db_engine_specs/test_drill.py
@@ -20,15 +20,16 @@ from datetime import datetime
 from typing import Optional
 
 import pytest
+from pytest_mock import MockerFixture
 from sqlalchemy.engine.url import make_url
 
 from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
 from tests.unit_tests.fixtures.common import dttm  # noqa: F401
 
 
-def test_odbc_impersonation() -> None:
+def test_odbc_impersonation(mocker: MockerFixture) -> None:
     """
-    Test ``get_url_for_impersonation`` method when driver == odbc.
+    Test ``impersonate_user`` method when driver == odbc.
 
     The method adds the parameter ``DelegationUID`` to the query string.
     """
@@ -36,15 +37,23 @@ def test_odbc_impersonation() -> None:
 
     from superset.db_engine_specs.drill import DrillEngineSpec
 
+    database = mocker.MagicMock()
+
     url = URL.create("drill+odbc")
     username = "DoAsUser"
-    url = DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
+    url, _ = DrillEngineSpec.impersonate_user(
+        database=database,
+        username=username,
+        user_token=None,
+        url=url,
+        engine_kwargs={},
+    )
     assert url.query["DelegationUID"] == username
 
 
-def test_jdbc_impersonation() -> None:
+def test_jdbc_impersonation(mocker: MockerFixture) -> None:
     """
-    Test ``get_url_for_impersonation`` method when driver == jdbc.
+    Test ``impersonate_user`` method when driver == jdbc.
 
     The method adds the parameter ``impersonation_target`` to the query string.
     """
@@ -52,15 +61,23 @@ def test_jdbc_impersonation() -> None:
 
     from superset.db_engine_specs.drill import DrillEngineSpec
 
+    database = mocker.MagicMock()
+
     url = URL.create("drill+jdbc")
     username = "DoAsUser"
-    url = DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
+    url, _ = DrillEngineSpec.impersonate_user(
+        database=database,
+        username=username,
+        user_token=None,
+        url=url,
+        engine_kwargs={},
+    )
     assert url.query["impersonation_target"] == username
 
 
-def test_sadrill_impersonation() -> None:
+def test_sadrill_impersonation(mocker: MockerFixture) -> None:
     """
-    Test ``get_url_for_impersonation`` method when driver == sadrill.
+    Test ``impersonate_user`` method when driver == sadrill.
 
     The method adds the parameter ``impersonation_target`` to the query string.
     """
@@ -68,15 +85,23 @@ def test_sadrill_impersonation() -> None:
 
     from superset.db_engine_specs.drill import DrillEngineSpec
 
+    database = mocker.MagicMock()
+
     url = URL.create("drill+sadrill")
     username = "DoAsUser"
-    url = DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
+    url, _ = DrillEngineSpec.impersonate_user(
+        database=database,
+        username=username,
+        user_token=None,
+        url=url,
+        engine_kwargs={},
+    )
     assert url.query["impersonation_target"] == username
 
 
-def test_invalid_impersonation() -> None:
+def test_invalid_impersonation(mocker: MockerFixture) -> None:
     """
-    Test ``get_url_for_impersonation`` method when driver == foobar.
+    Test ``impersonate_user`` method when driver == foobar.
 
     The method raises an exception because impersonation is not supported
     for drill+foobar.
@@ -86,11 +111,19 @@ def test_invalid_impersonation() -> None:
     from superset.db_engine_specs.drill import DrillEngineSpec
     from superset.db_engine_specs.exceptions import 
SupersetDBAPIProgrammingError
 
+    database = mocker.MagicMock()
+
     url = URL.create("drill+foobar")
     username = "DoAsUser"
 
     with pytest.raises(SupersetDBAPIProgrammingError):
-        DrillEngineSpec.get_url_for_impersonation(url, True, username, None)
+        DrillEngineSpec.impersonate_user(
+            database=database,
+            username=username,
+            user_token=None,
+            url=url,
+            engine_kwargs={},
+        )
 
 
 @pytest.mark.parametrize(
diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py
index 744be28b50..02dc4e3803 100644
--- a/tests/unit_tests/db_engine_specs/test_gsheets.py
+++ b/tests/unit_tests/db_engine_specs/test_gsheets.py
@@ -496,9 +496,9 @@ def test_upload_existing(mocker: MockerFixture) -> None:
     )
 
 
-def test_get_url_for_impersonation_username(mocker: MockerFixture) -> None:
+def test_impersonate_user_username(mocker: MockerFixture) -> None:
     """
-    Test passing a username to `get_url_for_impersonation`.
+    Test passing a username to `impersonate_user`.
     """
     from superset.db_engine_specs.gsheets import GSheetsEngineSpec
 
@@ -508,27 +508,32 @@ def test_get_url_for_impersonation_username(mocker: MockerFixture) -> None:
         "superset.db_engine_specs.gsheets.security_manager.find_user",
         return_value=user,
     )
+    database = mocker.MagicMock()
 
-    assert GSheetsEngineSpec.get_url_for_impersonation(
-        url=make_url("gsheets://"),
-        impersonate_user=True,
+    assert GSheetsEngineSpec.impersonate_user(
+        database,
         username="alice",
-        access_token=None,
-    ) == make_url("gsheets://?subject=alice%40example.org")
+        user_token=None,
+        url=make_url("gsheets://"),
+        engine_kwargs={},
+    ) == (make_url("gsheets://?subject=alice%40example.org"), {})
 
 
-def test_get_url_for_impersonation_access_token() -> None:
+def test_impersonate_user_access_token(mocker: MockerFixture) -> None:
     """
-    Test passing an access token to `get_url_for_impersonation`.
+    Test passing an access token to `impersonate_user`.
     """
     from superset.db_engine_specs.gsheets import GSheetsEngineSpec
 
-    assert GSheetsEngineSpec.get_url_for_impersonation(
-        url=make_url("gsheets://"),
-        impersonate_user=True,
+    database = mocker.MagicMock()
+
+    assert GSheetsEngineSpec.impersonate_user(
+        database,
         username=None,
-        access_token="access-token",  # noqa: S106
-    ) == make_url("gsheets://?access_token=access-token")
+        user_token="access-token",  # noqa: S106
+        url=make_url("gsheets://"),
+        engine_kwargs={},
+    ) == (make_url("gsheets://?access_token=access-token"), {})
 
 
 def test_is_oauth2_enabled_no_config(mocker: MockerFixture) -> None:
diff --git a/tests/unit_tests/db_engine_specs/test_starrocks.py b/tests/unit_tests/db_engine_specs/test_starrocks.py
index 45a68fd62f..67016a0801 100644
--- a/tests/unit_tests/db_engine_specs/test_starrocks.py
+++ b/tests/unit_tests/db_engine_specs/test_starrocks.py
@@ -131,7 +131,7 @@ def test_get_schema_from_engine_params() -> None:
 
 def test_impersonation_username(mocker: MockerFixture) -> None:
     """
-    Test impersonation and make sure that `get_url_for_impersonation` leaves 
the URL
+    Test impersonation and make sure that `impersonate_user` leaves the URL
     unchanged and that `get_prequeries` returns the appropriate impersonation 
query.
     """
     from superset.db_engine_specs.starrocks import StarRocksEngineSpec
@@ -140,12 +140,13 @@ def test_impersonation_username(mocker: MockerFixture) -> None:
     database.impersonate_user = True
     database.get_effective_user.return_value = "alice"
 
-    assert StarRocksEngineSpec.get_url_for_impersonation(
-        url=make_url("starrocks://service_user@localhost:9030/hive.default"),
-        impersonate_user=True,
+    assert StarRocksEngineSpec.impersonate_user(
+        database,
         username="alice",
-        access_token=None,
-    ) == make_url("starrocks://service_user@localhost:9030/hive.default")
+        user_token=None,
+        url=make_url("starrocks://service_user@localhost:9030/hive.default"),
+        engine_kwargs={},
+    ) == (make_url("starrocks://service_user@localhost:9030/hive.default"), {})
 
     assert StarRocksEngineSpec.get_prequeries(database) == [
         'EXECUTE AS "alice" WITH NO REVERT;'
@@ -155,7 +156,7 @@ def test_impersonation_username(mocker: MockerFixture) -> None:
 def test_impersonation_disabled(mocker: MockerFixture) -> None:
     """
     Test that impersonation is not applied when the feature is disabled in
-    `get_url_for_impersonation` and `get_prequeries`.
+    `impersonate_user` and `get_prequeries`.
     """
     from superset.db_engine_specs.starrocks import StarRocksEngineSpec
 
@@ -163,11 +164,12 @@ def test_impersonation_disabled(mocker: MockerFixture) -> None:
     database.impersonate_user = False
     database.get_effective_user.return_value = "alice"
 
-    assert StarRocksEngineSpec.get_url_for_impersonation(
-        url=make_url("starrocks://service_user@localhost:9030/hive.default"),
-        impersonate_user=False,
+    assert StarRocksEngineSpec.impersonate_user(
+        database,
         username="alice",
-        access_token=None,
-    ) == make_url("starrocks://service_user@localhost:9030/hive.default")
+        user_token=None,
+        url=make_url("starrocks://service_user@localhost:9030/hive.default"),
+        engine_kwargs={},
+    ) == (make_url("starrocks://service_user@localhost:9030/hive.default"), {})
 
     assert StarRocksEngineSpec.get_prequeries(database) == []


Reply via email to